Merge remote-tracking branch 'origin/multi_keys_channel' into alpha

# Conflicts:
#	web/src/components/table/LogsTable.js
#	web/src/i18n/locales/en.json
#	web/src/pages/Channel/EditChannel.js
This commit is contained in:
t0ng7u
2025-07-12 23:47:24 +08:00
98 changed files with 2261 additions and 1211 deletions

View File

@@ -32,7 +32,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = UnmarshalJson(requestBody, &v)
err = Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this

View File

@@ -5,7 +5,7 @@ import (
"encoding/json"
)
func UnmarshalJson(data []byte, v any) error {
func Unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}
@@ -17,6 +17,6 @@ func DecodeJson(reader *bytes.Reader, v any) error {
return json.NewDecoder(reader).Decode(v)
}
func EncodeJson(v any) ([]byte, error) {
func Marshal(v any) ([]byte, error) {
return json.Marshal(v)
}

View File

@@ -32,16 +32,30 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes)
}
func StrToMap(str string) map[string]interface{} {
func StrToMap(str string) (map[string]interface{}, error) {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
err := Unmarshal([]byte(str), &m)
if err != nil {
return nil
return nil, err
}
return m
return m, nil
}
func IsJsonStr(str string) bool {
func StrToJsonArray(str string) ([]interface{}, error) {
var js []interface{}
err := json.Unmarshal([]byte(str), &js)
if err != nil {
return nil, err
}
return js, nil
}
func IsJsonArray(str string) bool {
var js []interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func IsJsonObject(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}

View File

@@ -17,11 +17,20 @@ const (
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
/* channel related keys */
ContextKeyBaseUrl ContextKey = "base_url"
ContextKeyChannelType ContextKey = "channel_type"
ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelSetting ContextKey = "channel_setting"
ContextKeyParamOverride ContextKey = "param_override"
ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelName ContextKey = "channel_name"
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
ContextKeyChannelBaseUrl ContextKey = "base_url"
ContextKeyChannelType ContextKey = "channel_type"
ContextKeyChannelSetting ContextKey = "channel_setting"
ContextKeyChannelParamOverride ContextKey = "param_override"
ContextKeyChannelOrganization ContextKey = "channel_organization"
ContextKeyChannelAutoBan ContextKey = "auto_ban"
ContextKeyChannelModelMapping ContextKey = "model_mapping"
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
ContextKeyChannelKey ContextKey = "channel_key"
/* user related keys */
ContextKeyUserId ContextKey = "id"

View File

@@ -0,0 +1,8 @@
package constant
type MultiKeyMode string
const (
MultiKeyModeRandom MultiKeyMode = "random" // 随机
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
)

View File

@@ -12,6 +12,7 @@ import (
"one-api/model"
"one-api/service"
"one-api/setting"
"one-api/types"
"strconv"
"time"
@@ -415,7 +416,7 @@ func UpdateChannelBalance(c *gin.Context) {
})
return
}
channel, err := model.GetChannelById(id, true)
channel, err := model.CacheGetChannel(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -423,6 +424,13 @@ func UpdateChannelBalance(c *gin.Context) {
})
return
}
if channel.ChannelInfo.IsMultiKey {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "多密钥渠道不支持余额查询",
})
return
}
balance, err := updateChannelBalance(channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -436,7 +444,6 @@ func UpdateChannelBalance(c *gin.Context) {
"message": "",
"balance": balance,
})
return
}
func updateAllChannelsBalance() error {
@@ -448,6 +455,9 @@ func updateAllChannelsBalance() error {
if channel.Status != common.ChannelStatusEnabled {
continue
}
if channel.ChannelInfo.IsMultiKey {
continue // skip multi-key channels
}
// TODO: support Azure
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
// continue
@@ -458,7 +468,7 @@ func updateAllChannelsBalance() error {
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
service.DisableChannel(channel.Id, channel.Name, "余额不足")
service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
}
}
time.Sleep(common.RequestInterval)

View File

@@ -19,6 +19,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strconv"
"strings"
"sync"
@@ -29,22 +30,43 @@ import (
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
type testResult struct {
context *gin.Context
localErr error
newAPIError *types.NewAPIError
}
func testChannel(channel *model.Channel, testModel string) testResult {
tik := time.Now()
if channel.Type == constant.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
return testResult{
localErr: errors.New("midjourney channel test is not supported"),
newAPIError: nil,
}
}
if channel.Type == constant.ChannelTypeMidjourneyPlus {
return errors.New("midjourney plus channel test is not supported"), nil
return testResult{
localErr: errors.New("midjourney plus channel test is not supported"),
newAPIError: nil,
}
}
if channel.Type == constant.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil
return testResult{
localErr: errors.New("suno channel test is not supported"),
newAPIError: nil,
}
}
if channel.Type == constant.ChannelTypeKling {
return errors.New("kling channel test is not supported"), nil
return testResult{
localErr: errors.New("kling channel test is not supported"),
newAPIError: nil,
}
}
if channel.Type == constant.ChannelTypeJimeng {
return errors.New("jimeng channel test is not supported"), nil
return testResult{
localErr: errors.New("jimeng channel test is not supported"),
newAPIError: nil,
}
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -81,31 +103,49 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
cache, err := model.GetUserCache(1)
if err != nil {
return err, nil
return testResult{
localErr: err,
newAPIError: nil,
}
}
cache.WriteContext(c)
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
group, _ := model.GetUserGroup(1, false)
c.Set("group", group)
middleware.SetupContextForSelectedChannel(c, channel, testModel)
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
if newAPIError != nil {
return testResult{
context: c,
localErr: newAPIError,
newAPIError: newAPIError,
}
}
info := relaycommon.GenRelayInfo(c)
err = helper.ModelMappedHelper(c, info, nil)
if err != nil {
return err, nil
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
}
}
testModel = info.UpstreamModelName
apiType, _ := common.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
return testResult{
context: c,
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
}
}
request := buildTestRequest(testModel)
@@ -116,45 +156,77 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
if err != nil {
return err, nil
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
}
}
adaptor.Init(info)
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
if err != nil {
return err, nil
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
}
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return err, nil
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
}
}
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return err, nil
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
}
}
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp, true)
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
}
}
}
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr
return testResult{
context: c,
localErr: respErr,
newAPIError: respErr,
}
}
if usageA == nil {
return errors.New("usage is nil"), nil
return testResult{
context: c,
localErr: errors.New("usage is nil"),
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
}
}
usage := usageA.(*dto.Usage)
result := w.Result()
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
}
}
info.PromptTokens = usage.PromptTokens
@@ -187,7 +259,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
Other: other,
})
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
return testResult{
context: c,
localErr: nil,
newAPIError: nil,
}
}
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
@@ -236,7 +312,7 @@ func TestChannel(c *gin.Context) {
})
return
}
channel, err := model.GetChannelById(channelId, true)
channel, err := model.CacheGetChannel(channelId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -244,17 +320,30 @@ func TestChannel(c *gin.Context) {
})
return
}
//defer func() {
// if channel.ChannelInfo.IsMultiKey {
// go func() { _ = channel.SaveChannelInfo() }()
// }
//}()
testModel := c.Query("model")
tik := time.Now()
err, _ = testChannel(channel, testModel)
result := testChannel(channel, testModel)
if result.localErr != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": result.localErr.Error(),
"time": 0.0,
})
return
}
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
if result.newAPIError != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"message": result.newAPIError.Error(),
"time": consumedTime,
})
return
@@ -279,9 +368,9 @@ func testAllChannels(notify bool) error {
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true, false)
if err != nil {
return err
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
if getChannelErr != nil {
return getChannelErr
}
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
@@ -298,32 +387,34 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
err, openaiWithStatusErr := testChannel(channel, "")
result := testChannel(channel, "")
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
shouldBanChannel := false
newAPIError := result.newAPIError
// request error disables the channel
if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
if newAPIError != nil {
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
}
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
shouldBanChannel = true
// 当错误检查通过,才检查响应时间
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
if milliseconds > disableThreshold {
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
shouldBanChannel = true
}
}
// disable channel
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
service.DisableChannel(channel.Id, channel.Name, err.Error())
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
}
// enable channel
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
}
channel.UpdateResponseTime(milliseconds)

View File

@@ -380,9 +380,47 @@ func GetChannel(c *gin.Context) {
return
}
type AddChannelRequest struct {
Mode string `json:"mode"`
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
Channel *model.Channel `json:"channel"`
}
func getVertexArrayKeys(keys string) ([]string, error) {
if keys == "" {
return nil, nil
}
var keyArray []interface{}
err := common.Unmarshal([]byte(keys), &keyArray)
if err != nil {
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式例如[{key1}, {key2}...],请检查输入: %w", err)
}
cleanKeys := make([]string, 0, len(keyArray))
for _, key := range keyArray {
var keyStr string
switch v := key.(type) {
case string:
keyStr = strings.TrimSpace(v)
default:
bytes, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
}
keyStr = string(bytes)
}
if keyStr != "" {
cleanKeys = append(cleanKeys, keyStr)
}
}
if len(cleanKeys) == 0 {
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
}
return cleanKeys, nil
}
func AddChannel(c *gin.Context) {
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)
addChannelRequest := AddChannelRequest{}
err := c.ShouldBindJSON(&addChannelRequest)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -390,7 +428,8 @@ func AddChannel(c *gin.Context) {
})
return
}
err = channel.ValidateSettings()
err = addChannelRequest.Channel.ValidateSettings()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -398,49 +437,113 @@ func AddChannel(c *gin.Context) {
})
return
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
if channel.Type == constant.ChannelTypeVertexAi {
if channel.Other == "" {
if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "channel cannot be empty",
})
return
}
// Validate the length of the model name
for _, m := range addChannelRequest.Channel.GetModels() {
if len(m) > 255 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("模型名称过长: %s", m),
})
return
}
}
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
if addChannelRequest.Channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须是标准的Json格式例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
})
return
}
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
keys = []string{channel.Key}
}
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
keys := make([]string, 0)
switch addChannelRequest.Mode {
case "multi_to_single":
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
addChannelRequest.Channel.Key = strings.Join(array, "\n")
} else {
cleanKeys := make([]string, 0)
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
if key == "" {
continue
}
key = strings.TrimSpace(key)
cleanKeys = append(cleanKeys, key)
}
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
}
keys = []string{addChannelRequest.Channel.Key}
case "batch":
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
// multi json
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
}
case "single":
keys = []string{addChannelRequest.Channel.Key}
default:
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不支持的添加模式",
})
return
}
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {
if key == "" {
continue
}
localChannel := channel
localChannel := addChannelRequest.Channel
localChannel.Key = key
// Validate the length of the model name
models := strings.Split(localChannel.Models, ",")
for _, model := range models {
if len(model) > 255 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("模型名称过长: %s", model),
})
return
}
}
channels = append(channels, localChannel)
channels = append(channels, *localChannel)
}
err = model.BatchInsertChannels(channels)
if err != nil {
@@ -615,8 +718,13 @@ func DeleteChannelBatch(c *gin.Context) {
return
}
type PatchChannel struct {
model.Channel
MultiKeyMode *string `json:"multi_key_mode"`
}
func UpdateChannel(c *gin.Context) {
channel := model.Channel{}
channel := PatchChannel{}
err := c.ShouldBindJSON(&channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -641,17 +749,34 @@ func UpdateChannel(c *gin.Context) {
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
regionMap, err := common.StrToMap(channel.Other)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须是标准的Json格式例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
})
return
}
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
originChannel, err := model.GetChannelById(channel.Id, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
if originChannel.ChannelInfo.IsMultiKey {
channel.ChannelInfo = originChannel.ChannelInfo
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
}
}
err = channel.Update()

View File

@@ -3,45 +3,44 @@ package controller
import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/service"
"one-api/setting"
"one-api/types"
"time"
"github.com/gin-gonic/gin"
)
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
var newAPIError *types.NewAPIError
defer func() {
if openaiErr != nil {
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
if newAPIError != nil {
c.JSON(newAPIError.StatusCode, gin.H{
"error": newAPIError.ToOpenAIError(),
})
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
return
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
return
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
@@ -52,26 +51,32 @@ func Playground(c *gin.Context) {
group = userGroup
} else {
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
return
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
userId := c.GetInt("id")
//c.Set("token_name", "playground-"+group)
tempToken := &model.Token{
UserId: userId,
Name: fmt.Sprintf("playground-%s", group),
Group: group,
}
_ = middleware.SetupContextForToken(c, tempToken)
_, err = getChannel(c, group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
// Write user context to ensure acceptUnsetRatio is available
userId := c.GetInt("id")
userCache, err := model.GetUserCache(userId)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
return
}
userCache.WriteContext(c)

View File

@@ -17,14 +17,15 @@ import (
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
var err *types.NewAPIError
switch relayMode {
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
err = relay.ImageHelper(c)
@@ -55,14 +56,14 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
other["error_type"] = err.Error.Type
other["error_code"] = err.Error.Code
other["error_type"] = err.ErrorType
other["error_code"] = err.GetErrorCode()
other["status_code"] = err.StatusCode
other["channel_id"] = channelId
other["channel_name"] = c.GetString("channel_name")
other["channel_type"] = c.GetInt("channel_type")
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
}
return err
@@ -73,25 +74,25 @@ func Relay(c *gin.Context) {
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
var newAPIError *types.NewAPIError
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
newAPIError = err
break
}
openaiErr = relayRequest(c, relayMode, channel)
newAPIError = relayRequest(c, relayMode, channel)
if openaiErr == nil {
if newAPIError == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break
}
}
@@ -101,14 +102,14 @@ func Relay(c *gin.Context) {
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
if newAPIError != nil {
//if newAPIError.StatusCode == http.StatusTooManyRequests {
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
//}
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
c.JSON(newAPIError.StatusCode, gin.H{
"error": newAPIError.ToOpenAIError(),
})
}
}
@@ -127,8 +128,7 @@ func WssRelay(c *gin.Context) {
defer ws.Close()
if err != nil {
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
helper.WssError(c, ws, openaiErr.Error)
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
return
}
@@ -137,25 +137,25 @@ func WssRelay(c *gin.Context) {
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
var newAPIError *types.NewAPIError
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
newAPIError = err
break
}
openaiErr = wssRequest(c, ws, relayMode, channel)
newAPIError = wssRequest(c, ws, relayMode, channel)
if openaiErr == nil {
if newAPIError == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break
}
}
@@ -165,12 +165,12 @@ func WssRelay(c *gin.Context) {
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
helper.WssError(c, ws, openaiErr.Error)
if newAPIError != nil {
//if newAPIError.StatusCode == http.StatusTooManyRequests {
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
//}
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
helper.WssError(c, ws, newAPIError.ToOpenAIError())
}
}
@@ -179,27 +179,25 @@ func RelayClaude(c *gin.Context) {
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var claudeErr *dto.ClaudeErrorWithStatusCode
var newAPIError *types.NewAPIError
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
newAPIError = err
break
}
claudeErr = claudeRequest(c, channel)
newAPIError = claudeRequest(c, channel)
if claudeErr == nil {
if newAPIError == nil {
return // 成功处理请求,直接返回
}
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break
}
}
@@ -209,30 +207,30 @@ func RelayClaude(c *gin.Context) {
common.LogInfo(c, retryLogStr)
}
if claudeErr != nil {
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
c.JSON(claudeErr.StatusCode, gin.H{
if newAPIError != nil {
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
c.JSON(newAPIError.StatusCode, gin.H{
"type": "error",
"error": claudeErr.Error,
"error": newAPIError.ToClaudeError(),
})
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.WssHelper(c, ws)
}
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
@@ -245,7 +243,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
c.Set("use_channel", useChannel)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
@@ -259,19 +257,28 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
if group == "auto" {
return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
}
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
}
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
if newAPIError != nil {
return nil, newAPIError
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
if openaiErr == nil {
return false
}
if openaiErr.LocalError {
if types.IsChannelError(openaiErr) {
return true
}
if types.IsLocalError(openaiErr) {
return false
}
if retryTimes <= 0 {
@@ -310,12 +317,12 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
return true
}
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan {
service.DisableChannel(channelId, channelName, err.Error.Message)
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
service.DisableChannel(channelError, err.Error())
}
}
@@ -388,9 +395,10 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
channel, newAPIError := getChannel(c, group, originalModel, i)
if newAPIError != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
break
}
channelId = channel.Id
@@ -398,9 +406,9 @@ func RelayTask(c *gin.Context) {
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
taskErr = taskRelayHandler(c, relayMode)
}

View File

@@ -3,6 +3,7 @@ package dto
import (
"encoding/json"
"one-api/common"
"one-api/types"
)
type ClaudeMetadata struct {
@@ -228,7 +229,7 @@ type ClaudeResponse struct {
Completion string `json:"completion,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
Model string `json:"model,omitempty"`
Error *ClaudeError `json:"error,omitempty"`
Error *types.ClaudeError `json:"error,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
Index *int `json:"index,omitempty"`
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`

View File

@@ -1,5 +1,7 @@
package dto
import "one-api/types"
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
@@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct {
}
type GeneralErrorResponse struct {
Error OpenAIError `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Error types.OpenAIError `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`

View File

@@ -66,8 +66,8 @@ type GeneralOpenAIRequest struct {
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
result := make(map[string]any)
data, _ := common.EncodeJson(r)
_ = common.UnmarshalJson(data, &result)
data, _ := common.Marshal(r)
_ = common.Unmarshal(data, &result)
return result
}

View File

@@ -1,6 +1,9 @@
package dto
import "encoding/json"
import (
"encoding/json"
"one-api/types"
)
type SimpleResponse struct {
Usage `json:"usage"`
@@ -28,7 +31,7 @@ type OpenAITextResponse struct {
Object string `json:"object"`
Created any `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Error *OpenAIError `json:"error,omitempty"`
Error *types.OpenAIError `json:"error,omitempty"`
Usage `json:"usage"`
}
@@ -201,7 +204,7 @@ type OpenAIResponsesResponse struct {
Object string `json:"object"`
CreatedAt int `json:"created_at"`
Status string `json:"status"`
Error *OpenAIError `json:"error,omitempty"`
Error *types.OpenAIError `json:"error,omitempty"`
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
Instructions string `json:"instructions"`
MaxOutputTokens int `json:"max_output_tokens"`

View File

@@ -1,5 +1,7 @@
package dto
import "one-api/types"
const (
RealtimeEventTypeError = "error"
RealtimeEventTypeSessionUpdate = "session.update"
@@ -23,12 +25,12 @@ type RealtimeEvent struct {
EventId string `json:"event_id"`
Type string `json:"type"`
//PreviousItemId string `json:"previous_item_id"`
Session *RealtimeSession `json:"session,omitempty"`
Item *RealtimeItem `json:"item,omitempty"`
Error *OpenAIError `json:"error,omitempty"`
Response *RealtimeResponse `json:"response,omitempty"`
Delta string `json:"delta,omitempty"`
Audio string `json:"audio,omitempty"`
Session *RealtimeSession `json:"session,omitempty"`
Item *RealtimeItem `json:"item,omitempty"`
Error *types.OpenAIError `json:"error,omitempty"`
Response *RealtimeResponse `json:"response,omitempty"`
Delta string `json:"delta,omitempty"`
Audio string `json:"audio,omitempty"`
}
type RealtimeResponse struct {

View File

@@ -1,6 +1,7 @@
package middleware
import (
"fmt"
"net/http"
"one-api/common"
"one-api/model"
@@ -233,30 +234,41 @@ func TokenAuth() func(c *gin.Context) {
userCache.WriteContext(c)
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)
c.Set("token_name", token.Name)
c.Set("token_unlimited_quota", token.UnlimitedQuota)
if !token.UnlimitedQuota {
c.Set("token_quota", token.RemainQuota)
}
if token.ModelLimitsEnabled {
c.Set("token_model_limit_enabled", true)
c.Set("token_model_limit", token.GetModelLimitsMap())
} else {
c.Set("token_model_limit_enabled", false)
}
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1])
} else {
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
err = SetupContextForToken(c, token, parts...)
if err != nil {
return
}
c.Next()
}
}
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
if token == nil {
return fmt.Errorf("token is nil")
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)
c.Set("token_name", token.Name)
c.Set("token_unlimited_quota", token.UnlimitedQuota)
if !token.UnlimitedQuota {
c.Set("token_quota", token.RemainQuota)
}
if token.ModelLimitsEnabled {
c.Set("token_model_limit_enabled", true)
c.Set("token_model_limit", token.GetModelLimitsMap())
} else {
c.Set("token_model_limit_enabled", false)
}
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1])
} else {
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return fmt.Errorf("普通用户不支持指定渠道")
}
}
return nil
}

View File

@@ -12,6 +12,7 @@ import (
"one-api/service"
"one-api/setting"
"one-api/setting/ratio_setting"
"one-api/types"
"strconv"
"strings"
"time"
@@ -21,6 +22,7 @@ import (
type ModelRequest struct {
Model string `json:"model"`
Group string `json:"group,omitempty"`
}
func Distribute() func(c *gin.Context) {
@@ -237,28 +239,47 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("relay_mode", relayMode)
}
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
// playground chat completions
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, false, errors.New("无效的请求, " + err.Error())
}
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
}
return &modelRequest, shouldSelectChannel, nil
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
c.Set("original_model", modelName) // for retry
if channel == nil {
return
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
}
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
c.Set("channel_create_time", channel.CreatedTime)
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
}
c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL())
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
key, index, newAPIError := channel.GetNextEnabledKey()
if newAPIError != nil {
return newAPIError
}
if channel.ChannelInfo.IsMultiKey {
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
}
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case constant.ChannelTypeAzure:
@@ -278,6 +299,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
case constant.ChannelTypeCoze:
c.Set("bot_id", channel.Other)
}
return nil
}
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名

View File

@@ -1,9 +1,15 @@
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"math/rand"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/types"
"strings"
"sync"
@@ -36,8 +42,126 @@ type Channel struct {
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"`
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
ParamOverride *string `json:"param_override" gorm:"type:text"`
// add after v0.8.5
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
}
type ChannelInfo struct {
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
}
// Value implements driver.Valuer interface
func (c ChannelInfo) Value() (driver.Value, error) {
return common.Marshal(&c)
}
// Scan implements sql.Scanner interface
func (c *ChannelInfo) Scan(value interface{}) error {
bytesValue, _ := value.([]byte)
return common.Unmarshal(bytesValue, c)
}
func (channel *Channel) getKeys() []string {
if channel.Key == "" {
return []string{}
}
// use \n to split keys
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
return keys
}
func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
// If not in multi-key mode, return the original key string directly.
if !channel.ChannelInfo.IsMultiKey {
return channel.Key, 0, nil
}
// Obtain all keys (split by \n)
keys := channel.getKeys()
if len(keys) == 0 {
// No keys available, return error, should disable the channel
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
}
statusList := channel.ChannelInfo.MultiKeyStatusList
// helper to get key status, default to enabled when missing
getStatus := func(idx int) int {
if statusList == nil {
return common.ChannelStatusEnabled
}
if status, ok := statusList[idx]; ok {
return status
}
return common.ChannelStatusEnabled
}
// Collect indexes of enabled keys
enabledIdx := make([]int, 0, len(keys))
for i := range keys {
if getStatus(i) == common.ChannelStatusEnabled {
enabledIdx = append(enabledIdx, i)
}
}
// If no specific status list or none enabled, fall back to first key
if len(enabledIdx) == 0 {
return keys[0], 0, nil
}
switch channel.ChannelInfo.MultiKeyMode {
case constant.MultiKeyModeRandom:
// Randomly pick one enabled key
selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
return keys[selectedIdx], selectedIdx, nil
case constant.MultiKeyModePolling:
// Use channel-specific lock to ensure thread-safe polling
lock := getChannelPollingLock(channel.Id)
lock.Lock()
defer lock.Unlock()
channelInfo, err := CacheGetChannelInfo(channel.Id)
if err != nil {
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed)
}
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
defer func() {
if common.DebugEnabled {
println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
}
if !common.MemoryCacheEnabled {
_ = channel.SaveChannelInfo()
} else {
// CacheUpdateChannel(channel)
}
}()
// Start from the saved polling index and look for the next enabled key
start := channelInfo.MultiKeyPollingIndex
if start < 0 || start >= len(keys) {
start = 0
}
for i := 0; i < len(keys); i++ {
idx := (start + i) % len(keys)
if getStatus(idx) == common.ChannelStatusEnabled {
// update polling index for next call (point to the next position)
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
return keys[idx], idx, nil
}
}
// Fallback should not happen, but return first enabled key
return keys[enabledIdx[0]], enabledIdx[0], nil
default:
// Unknown mode, default to first enabled key (or original key string)
return keys[enabledIdx[0]], enabledIdx[0], nil
}
}
func (channel *Channel) SaveChannelInfo() error {
return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
}
func (channel *Channel) GetModels() []string {
@@ -175,14 +299,20 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
}
func GetChannelById(id int, selectAll bool) (*Channel, error) {
channel := Channel{Id: id}
channel := &Channel{Id: id}
var err error = nil
if selectAll {
err = DB.First(&channel, "id = ?", id).Error
err = DB.First(channel, "id = ?", id).Error
} else {
err = DB.Omit("key").First(&channel, "id = ?", id).Error
err = DB.Omit("key").First(channel, "id = ?", id).Error
}
return &channel, err
if err != nil {
return nil, err
}
if channel == nil {
return nil, errors.New("channel not found")
}
return channel, nil
}
func BatchInsertChannels(channels []Channel) error {
@@ -308,48 +438,128 @@ func (channel *Channel) Delete() error {
var channelStatusLock sync.Mutex
func UpdateChannelStatusById(id int, status int, reason string) bool {
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
var channelPollingLocks sync.Map
// getChannelPollingLock returns or creates a mutex for the given channel ID
func getChannelPollingLock(channelId int) *sync.Mutex {
if lock, exists := channelPollingLocks.Load(channelId); exists {
return lock.(*sync.Mutex)
}
// Create new lock for this channel
newLock := &sync.Mutex{}
actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
return actual.(*sync.Mutex)
}
// CleanupChannelPollingLocks removes locks for channels that no longer exist
// This is optional and can be called periodically to prevent memory leaks
func CleanupChannelPollingLocks() {
var activeChannelIds []int
DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
activeChannelSet := make(map[int]bool)
for _, id := range activeChannelIds {
activeChannelSet[id] = true
}
channelPollingLocks.Range(func(key, value interface{}) bool {
channelId := key.(int)
if !activeChannelSet[channelId] {
channelPollingLocks.Delete(channelId)
}
return true
})
}
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
keys := channel.getKeys()
if len(keys) == 0 {
channel.Status = status
} else {
var keyIndex int
for i, key := range keys {
if key == usingKey {
keyIndex = i
break
}
}
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if status == common.ChannelStatusEnabled {
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
} else {
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
}
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
channel.Status = common.ChannelStatusAutoDisabled
info := channel.GetOtherInfo()
info["status_reason"] = "All keys are disabled"
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
}
}
}
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
if common.MemoryCacheEnabled {
channelStatusLock.Lock()
defer channelStatusLock.Unlock()
channelCache, _ := CacheGetChannel(id)
// 如果缓存渠道存在,且状态已是目标状态,直接返回
if channelCache != nil && channelCache.Status == status {
channelCache, _ := CacheGetChannel(channelId)
if channelCache == nil {
return false
}
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
if channelCache == nil && status != common.ChannelStatusEnabled {
return false
if channelCache.ChannelInfo.IsMultiKey {
// 如果是多Key模式更新缓存中的状态
handlerMultiKeyUpdate(channelCache, usingKey, status)
//CacheUpdateChannel(channelCache)
//return true
} else {
// 如果缓存渠道存在,且状态已是目标状态,直接返回
if channelCache.Status == status {
return false
}
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
if status != common.ChannelStatusEnabled {
return false
}
CacheUpdateChannelStatus(channelId, status)
}
CacheUpdateChannelStatus(id, status)
}
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
shouldUpdateAbilities := false
defer func() {
if shouldUpdateAbilities {
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
}
}
}()
channel, err := GetChannelById(channelId, true)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
return false
}
channel, err := GetChannelById(id, true)
if err != nil {
// find channel by id error, directly update status
result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
if result.Error != nil {
common.SysError("failed to update channel status: " + result.Error.Error())
return false
}
if result.RowsAffected == 0 {
return false
}
} else {
if channel.Status == status {
return false
}
// find channel by id success, update status and other info
info := channel.GetOtherInfo()
info["status_reason"] = reason
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
channel.Status = status
if channel.ChannelInfo.IsMultiKey {
beforeStatus := channel.Status
handlerMultiKeyUpdate(channel, usingKey, status)
if beforeStatus != channel.Status {
shouldUpdateAbilities = true
}
} else {
info := channel.GetOtherInfo()
info["status_reason"] = reason
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
channel.Status = status
shouldUpdateAbilities = true
}
err = channel.Save()
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
@@ -532,6 +742,8 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
err := json.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
channel.Setting = nil // 清空设置以避免后续错误
_ = channel.Save() // 保存修改
}
}
return setting

View File

@@ -14,8 +14,8 @@ import (
"github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var group2model2channels map[string]map[string][]int // enabled channel
var channelsIDM map[int]*Channel // all channels include disabled
var channelSyncLock sync.RWMutex
func InitChannelCache() {
@@ -24,7 +24,7 @@ func InitChannelCache() {
}
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
DB.Find(&channels)
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
}
@@ -34,21 +34,22 @@ func InitChannelCache() {
for _, ability := range abilities {
groups[ability.Group] = true
}
newGroup2model2channels := make(map[string]map[string][]*Channel)
newChannelsIDM := make(map[int]*Channel)
newGroup2model2channels := make(map[string]map[string][]int)
for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel)
newGroup2model2channels[group] = make(map[string][]int)
}
for _, channel := range channels {
newChannelsIDM[channel.Id] = channel
if channel.Status != common.ChannelStatusEnabled {
continue // skip disabled channels
}
groups := strings.Split(channel.Group, ",")
for _, group := range groups {
models := strings.Split(channel.Models, ",")
for _, model := range models {
if _, ok := newGroup2model2channels[group][model]; !ok {
newGroup2model2channels[group][model] = make([]*Channel, 0)
newGroup2model2channels[group][model] = make([]int, 0)
}
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
}
}
}
@@ -57,7 +58,7 @@ func InitChannelCache() {
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
})
newGroup2model2channels[group][model] = channels
}
@@ -65,7 +66,7 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelsIDM = newChannelsIDM
channelsIDM = newChannelId2channel
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
}
@@ -128,16 +129,27 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model]
channelSyncLock.RUnlock()
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
if len(channels) == 1 {
if channel, ok := channelsIDM[channels[0]]; ok {
return channel, nil
}
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
}
uniquePriorities := make(map[int]bool)
for _, channel := range channels {
uniquePriorities[int(channel.GetPriority())] = true
for _, channelId := range channels {
if channel, ok := channelsIDM[channelId]; ok {
uniquePriorities[int(channel.GetPriority())] = true
} else {
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
}
}
var sortedUniquePriorities []int
for priority := range uniquePriorities {
@@ -152,9 +164,13 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
// get the priority for the given retry number
var targetChannels []*Channel
for _, channel := range channels {
if channel.GetPriority() == targetPriority {
targetChannels = append(targetChannels, channel)
for _, channelId := range channels {
if channel, ok := channelsIDM[channelId]; ok {
if channel.GetPriority() == targetPriority {
targetChannels = append(targetChannels, channel)
}
} else {
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
}
}
@@ -188,11 +204,35 @@ func CacheGetChannel(id int) (*Channel, error) {
c, ok := channelsIDM[id]
if !ok {
return nil, errors.New(fmt.Sprintf("当前渠道# %d已不存在", id))
return nil, fmt.Errorf("渠道# %d已不存在", id)
}
if c.Status != common.ChannelStatusEnabled {
return nil, fmt.Errorf("渠道# %d已被禁用", id)
}
return c, nil
}
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
if !common.MemoryCacheEnabled {
channel, err := GetChannelById(id, true)
if err != nil {
return nil, err
}
return &channel.ChannelInfo, nil
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
c, ok := channelsIDM[id]
if !ok {
return nil, fmt.Errorf("渠道# %d已不存在", id)
}
if c.Status != common.ChannelStatusEnabled {
return nil, fmt.Errorf("渠道# %d已被禁用", id)
}
return &c.ChannelInfo, nil
}
func CacheUpdateChannelStatus(id int, status int) {
if !common.MemoryCacheEnabled {
return
@@ -203,3 +243,20 @@ func CacheUpdateChannelStatus(id int, status int) {
channel.Status = status
}
}
func CacheUpdateChannel(channel *Channel) {
if !common.MemoryCacheEnabled {
return
}
channelSyncLock.Lock()
defer channelSyncLock.Unlock()
if channel == nil {
return
}
println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
channelsIDM[channel.Id] = channel
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
}

View File

@@ -49,7 +49,7 @@ func formatUserLogs(logs []*Log) {
for i := range logs {
logs[i].ChannelName = ""
var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other)
otherMap, _ = common.StrToMap(logs[i].Other)
if otherMap != nil {
// delete admin
delete(otherMap, "admin_info")

View File

@@ -57,7 +57,7 @@ func initCol() {
}
}
// log sql type and database type
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
}
var DB *gorm.DB
@@ -225,12 +225,6 @@ func InitLogDB() (err error) {
if !common.IsMasterNode {
return nil
}
//if common.UsingMySQL {
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//}
common.SysLog("database migration started")
err = migrateLOGDB()
return err

View File

@@ -1,7 +1,6 @@
package model
import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/constant"
@@ -36,7 +35,7 @@ func (user *UserBase) WriteContext(c *gin.Context) {
func (user *UserBase) GetSetting() dto.UserSetting {
setting := dto.UserSetting{}
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
err := common.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}

View File

@@ -3,7 +3,6 @@ package relay
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/dto"
@@ -12,7 +11,10 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
@@ -54,13 +56,13 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return audioRequest, nil
}
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
promptTokens := 0
@@ -73,7 +75,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeModelPriceError)
}
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -88,23 +90,23 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
@@ -112,18 +114,18 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -21,7 +22,7 @@ type Adaptor interface {
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
GetModelList() []string
GetChannelName() string
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -99,7 +100,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)
@@ -109,9 +110,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = RerankHandler(c, resp, info)
default:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
}
return

View File

@@ -4,15 +4,17 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
@@ -124,49 +126,46 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
return &imageResponse
}
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if aliTaskResponse.Message != "" {
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
}
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
if err != nil {
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponse), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
},
StatusCode: resp.StatusCode,
}, nil
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
}, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, nil
c.Writer.Write(jsonResponse)
return nil, &dto.Usage{}
}

View File

@@ -7,7 +7,7 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -31,29 +31,26 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
}
}
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
}
common.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if aliResponse.Code != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
}, resp.StatusCode), nil
}
usage := dto.Usage{
@@ -68,14 +65,10 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
jsonResponse, err := json.Marshal(rerankResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -8,9 +8,10 @@ import (
"one-api/common"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -38,11 +39,11 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var fullTextResponse dto.OpenAIEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
@@ -53,11 +54,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
@@ -119,7 +120,7 @@ func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStre
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
@@ -174,32 +175,29 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if aliResponse.Code != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Message,
Type: "ali_error",
Param: aliResponse.RequestId,
Code: aliResponse.Code,
}, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

View File

@@ -8,6 +8,7 @@ import (
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -84,7 +85,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else {

View File

@@ -3,19 +3,22 @@ package aws
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
)
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
@@ -65,24 +68,21 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
return modelPrefix + "." + awsModelId
}
func awsModelID(requestModel string) (string, error) {
func awsModelID(requestModel string) string {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID, nil
return awsModelID
}
return requestModel, nil
return requestModel
}
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
awsModelId, err := awsModelID(c.GetString("request_model"))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsModelId := awsModelID(c.GetString("request_model"))
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
@@ -98,42 +98,42 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
claudeReq_, ok := c.Get("converted_request")
if !ok {
return wrapErr(errors.New("request not found")), nil
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
}
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
if handlerErr != nil {
return handlerErr, nil
}
return nil, claudeInfo.Usage
}
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
awsModelId, err := awsModelID(c.GetString("request_model"))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsModelId := awsModelID(c.GetString("request_model"))
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
@@ -149,25 +149,25 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
claudeReq_, ok := c.Get("converted_request")
if !ok {
return wrapErr(errors.New("request not found")), nil
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
}
stream := awsResp.GetStream()
defer stream.Close()
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
@@ -176,18 +176,18 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
for event := range stream.Events() {
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
info.SetFirstResponseTime()
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
if respErr != nil {
return respErr, nil
}
case *types.UnknownUnionMember:
case *bedrockruntimeTypes.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return wrapErr(errors.New("unknown response type")), nil
return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
default:
fmt.Println("union is nil or unknown type")
return wrapErr(errors.New("nil or unknown response type")), nil
return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
}
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -140,15 +141,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = baiduStreamHandler(c, resp)
err, usage = baiduStreamHandler(c, info, resp)
} else {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
err, usage = baiduEmbeddingHandler(c, info, resp)
default:
err, usage = baiduHandler(c, resp)
err, usage = baiduHandler(c, info, resp)
}
}
return

View File

@@ -1,21 +1,23 @@
package baidu
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
@@ -110,92 +112,49 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
return &openAIEmbeddingResponse
}
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
data = data[6:]
dataChan <- data
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse BaiduChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
usage := &dto.Usage{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var baiduResponse BaiduChatStreamResponse
err := common.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
err = helper.ObjectData(c, response)
if err != nil {
common.SysError("error sending stream response: " + err.Error())
}
return true
})
common.CloseResponseBodyGracefully(resp)
return nil, &usage
return nil, usage
}
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var baiduResponse BaiduChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if baiduResponse.ErrorMsg != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}, nil
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
}
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -203,32 +162,24 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
return nil, &fullTextResponse.Usage
}
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var baiduResponse BaiduEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if baiduResponse.ErrorMsg != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}, nil
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
}
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -92,11 +93,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -94,7 +95,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else {

View File

@@ -12,6 +12,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -125,7 +126,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
if textRequest.Reasoning != nil {
var reasoning openrouter.RequestReasoning
if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
return nil, err
}
@@ -517,22 +518,15 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
return true
}
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Code: "stream_response_error",
Type: claudeResponse.Error.Type,
Message: claudeResponse.Error.Message,
},
StatusCode: http.StatusInternalServerError,
}
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
@@ -593,15 +587,15 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
}
}
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
claudeInfo := &ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
var err *dto.OpenAIErrorWithStatusCode
var err *types.NewAPIError
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
if err != nil {
@@ -617,21 +611,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
return nil, claudeInfo.Usage
}
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
err := common.UnmarshalJson(data, &claudeResponse)
err := common.Unmarshal(data, &claudeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Code: claudeResponse.Error.Type,
},
StatusCode: http.StatusInternalServerError,
}
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
}
if requestMode == RequestModeCompletion {
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
@@ -652,7 +639,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
case relaycommon.RelayFormatClaude:
responseData = data
@@ -662,11 +649,11 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
@@ -674,7 +661,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if common.DebugEnabled {
println("responseBody: ", string(responseBody))

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -94,20 +95,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fallthrough
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = cfStreamHandler(c, resp, info)
err, usage = cfStreamHandler(c, info, resp)
} else {
err, usage = cfHandler(c, resp, info)
err, usage = cfHandler(c, info, resp)
}
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err, usage = cfSTTHandler(c, resp, info)
err, usage = cfSTTHandler(c, info, resp)
}
return
}

View File

@@ -3,7 +3,6 @@ package cloudflare
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -11,8 +10,11 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
@@ -25,7 +27,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
}
}
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
@@ -86,16 +88,16 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
return nil, usage
}
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
response.Model = info.UpstreamModelName
var responseText string
@@ -107,7 +109,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -115,16 +117,16 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
return nil, usage
}
func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var cfResp CfAudioResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &cfResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
audioResp := &dto.AudioResponse{
@@ -133,7 +135,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
jsonResponse, err := json.Marshal(audioResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -71,14 +72,14 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)
usage, err = cohereRerankHandler(c, resp, info)
} else {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info)
usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
usage, err = cohereHandler(c, info, resp)
}
}
return

View File

@@ -3,7 +3,6 @@ package cohere
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -11,8 +10,11 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
@@ -76,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string {
}
}
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
@@ -164,20 +166,20 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if usage.PromptTokens == 0 {
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return nil, usage
return usage, nil
}
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
createdTime := common.GetTimestamp()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := dto.Usage{}
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
@@ -188,7 +190,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
openaiResp.Id = cohereResp.ResponseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion"
openaiResp.Model = modelName
openaiResp.Model = info.UpstreamModelName
openaiResp.Usage = usage
openaiResp.Choices = []dto.OpenAITextResponseChoice{
@@ -201,24 +203,24 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
jsonResponse, err := json.Marshal(openaiResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
_, _ = c.Writer.Write(jsonResponse)
return &usage, nil
}
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := dto.Usage{}
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
@@ -237,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
return &usage, nil
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/common"
"one-api/types"
"time"
"github.com/gin-gonic/gin"
@@ -95,11 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
}
// DoResponse implements channel.Adaptor.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = cozeChatStreamHandler(c, resp, info)
usage, err = cozeChatStreamHandler(c, info, resp)
} else {
err, usage = cozeChatHandler(c, resp, info)
usage, err = cozeChatHandler(c, info, resp)
}
return
}

View File

@@ -12,6 +12,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -43,10 +44,10 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
return cozeRequest
}
func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
// convert coze response to openai response
@@ -55,10 +56,10 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
response.Model = info.UpstreamModelName
err = json.Unmarshal(responseBody, &cozeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if cozeResponse.Code != 0 {
return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
}
// 从上下文获取 usage
var usage dto.Usage
@@ -85,16 +86,16 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
}
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
return &usage, nil
}
func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
@@ -135,7 +136,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
}
if err := scanner.Err(); err != nil {
return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
helper.Done(c)
@@ -143,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
}
return nil, usage
return usage, nil
}
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -81,11 +82,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}

View File

@@ -8,6 +8,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -96,11 +97,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = difyStreamHandler(c, resp, info)
return difyStreamHandler(c, info, resp)
} else {
err, usage = difyHandler(c, resp, info)
return difyHandler(c, info, resp)
}
return
}

View File

@@ -14,6 +14,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"os"
"strings"
@@ -209,7 +210,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
return &response
}
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var responseText string
usage := &dto.Usage{}
var nodeToken int
@@ -247,20 +248,20 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
usage.CompletionTokens += nodeToken
return nil, usage
return usage, nil
}
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var difyResponse DifyChatCompletionResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
fullTextResponse := dto.OpenAITextResponse{
Id: difyResponse.ConversationId,
@@ -279,10 +280,10 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &difyResponse.MetaData.Usage
c.Writer.Write(jsonResponse)
return &difyResponse.MetaData.Usage, nil
}

View File

@@ -11,8 +11,8 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -168,30 +168,30 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
if info.IsStream {
return GeminiTextGenerationStreamHandler(c, resp, info)
return GeminiTextGenerationStreamHandler(c, info, resp)
} else {
return GeminiTextGenerationHandler(c, resp, info)
return GeminiTextGenerationHandler(c, info, resp)
}
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return GeminiImageHandler(c, resp, info)
return GeminiImageHandler(c, info, resp)
}
// check if the model is an embedding model
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return GeminiEmbeddingHandler(c, resp, info)
return GeminiEmbeddingHandler(c, info, resp)
}
if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info)
return GeminiChatStreamHandler(c, info, resp)
} else {
err, usage = GeminiChatHandler(c, resp, info)
return GeminiChatHandler(c, info, resp)
}
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
@@ -205,23 +205,23 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
// }
//}
return
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
}
func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
if len(geminiResponse.Predictions) == 0 {
return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
}
// convert to openai format response
@@ -241,7 +241,7 @@ func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
@@ -253,7 +253,7 @@ func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage = &dto.Usage{
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,

View File

@@ -8,18 +8,19 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if common.DebugEnabled {
@@ -28,9 +29,9 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
// 解析为 Gemini 原生响应格式
var geminiResponse GeminiChatResponse
err = common.UnmarshalJson(responseBody, &geminiResponse)
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
// 计算使用量(基于 UsageMetadata
@@ -51,9 +52,9 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
}
// 直接返回 Gemini 原生格式的 JSON 响应
jsonResponse, err := common.EncodeJson(geminiResponse)
jsonResponse, err := common.Marshal(geminiResponse)
if err != nil {
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)
@@ -61,7 +62,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
return &usage, nil
}
func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage = &dto.Usage{}
var imageCount int

View File

@@ -2,6 +2,7 @@ package gemini
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -12,6 +13,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"one-api/types"
"strconv"
"strings"
"unicode/utf8"
@@ -792,7 +794,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
return &response, isStop, hasImage
}
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
// responseText := ""
id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
@@ -858,33 +860,25 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
}
helper.Done(c)
//resp.Body.Close()
return nil, usage
return usage, nil
}
func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println(string(responseBody))
}
var geminiResponse GeminiChatResponse
err = common.UnmarshalJson(responseBody, &geminiResponse)
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if len(geminiResponse.Candidates) == 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}, nil
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
}
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
@@ -908,25 +902,25 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
c.Writer.Write(jsonResponse)
return &usage, nil
}
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
}
var geminiResponse GeminiEmbeddingResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
// convert to openai format response
@@ -947,16 +941,16 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
// Google has not yet clarified how embedding models will be billed
// refer to openai billing method to use input tokens billing
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
usage = &dto.Usage{
usage := &dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: 0,
TotalTokens: info.PromptTokens,
}
openAIResponse.Usage = *usage.(*dto.Usage)
openAIResponse.Usage = *usage
jsonResponse, jsonErr := common.EncodeJson(openAIResponse)
jsonResponse, jsonErr := common.Marshal(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)

View File

@@ -11,6 +11,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -73,11 +74,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = common_handler.RerankHandler(c, info, resp)
usage, err = common_handler.RerankHandler(c, info, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}

View File

@@ -8,6 +8,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -69,11 +70,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -84,11 +85,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = mokaEmbeddingHandler(c, resp)
return mokaEmbeddingHandler(c, info, resp)
default:
// err, usage = mokaHandler(c, resp)

View File

@@ -2,12 +2,14 @@ package mokaai
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
@@ -48,16 +50,16 @@ func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEm
return &openAIEmbeddingResponse
}
func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var baiduResponse dto.EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
// if baiduResponse.ErrorMsg != "" {
// return &dto.OpenAIErrorWithStatusCode{
@@ -69,12 +71,12 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
// }, nil
// }
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -74,14 +75,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
usage, err = ollamaEmbeddingHandler(c, info, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
}
return

View File

@@ -1,15 +1,17 @@
package ollama
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
@@ -82,19 +84,19 @@ func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequ
}
}
func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var ollamaEmbeddingResponse OllamaEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if ollamaEmbeddingResponse.Error != "" {
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
}
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
@@ -103,22 +105,22 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
Object: "embedding",
})
usage := &dto.Usage{
TotalTokens: promptTokens,
TotalTokens: info.PromptTokens,
CompletionTokens: 0,
PromptTokens: promptTokens,
PromptTokens: info.PromptTokens,
}
embeddingResponse := &dto.OpenAIEmbeddingResponse{
Object: "list",
Data: data,
Model: model,
Model: info.UpstreamModelName,
Usage: *usage,
}
doResponseBody, err := json.Marshal(embeddingResponse)
doResponseBody, err := common.Marshal(embeddingResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.IOCopyBytesGracefully(c, resp, doResponseBody)
return nil, usage
return usage, nil
}
func flattenEmbeddings(embeddings [][]float64) []float64 {

View File

@@ -22,6 +22,7 @@ import (
"one-api/relay/common_handler"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/types"
"path/filepath"
"strings"
@@ -421,31 +422,31 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case relayconstant.RelayModeRealtime:
err, usage = OpenaiRealtimeHandler(c, info)
case relayconstant.RelayModeAudioSpeech:
err, usage = OpenaiTTSHandler(c, resp, info)
usage = OpenaiTTSHandler(c, resp, info)
case relayconstant.RelayModeAudioTranslation:
fallthrough
case relayconstant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
err, usage = OpenaiHandlerWithUsage(c, resp, info)
usage, err = OpenaiHandlerWithUsage(c, info, resp)
case relayconstant.RelayModeRerank:
err, usage = common_handler.RerankHandler(c, info, resp)
usage, err = common_handler.RerankHandler(c, info, resp)
case relayconstant.RelayModeResponses:
if info.IsStream {
err, usage = OaiResponsesStreamHandler(c, resp, info)
usage, err = OaiResponsesStreamHandler(c, info, resp)
} else {
err, usage = OaiResponsesHandler(c, resp, info)
usage, err = OaiResponsesHandler(c, info, resp)
}
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)
usage, err = OaiStreamHandler(c, info, resp)
} else {
err, usage = OpenaiHandler(c, resp, info)
usage, err = OpenaiHandler(c, info, resp)
}
}
return

View File

@@ -17,6 +17,8 @@ import (
"path/filepath"
"strings"
"one-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -104,10 +106,10 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
return helper.ObjectData(c, lastStreamResponse)
}
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
}
defer common.CloseResponseBodyGracefully(resp)
@@ -177,26 +179,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
return nil, usage
return usage, nil
}
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
err = common.UnmarshalJson(responseBody, &simpleResponse)
err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: *simpleResponse.Error,
StatusCode: resp.StatusCode,
}, nil
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
}
forceFormat := false
@@ -220,28 +219,28 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
if forceFormat {
responseBody, err = common.EncodeJson(simpleResponse)
responseBody, err = common.Marshal(simpleResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
} else {
break
}
case relaycommon.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
claudeRespStr, err := common.EncodeJson(claudeResp)
claudeRespStr, err := common.Marshal(claudeResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
responseBody = claudeRespStr
}
common.IOCopyBytesGracefully(c, resp, responseBody)
return nil, &simpleResponse.Usage
return &simpleResponse.Usage, nil
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
// the status code has been judged before, if there is a body reading failure,
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
@@ -261,20 +260,20 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if err != nil {
common.LogError(c, err.Error())
}
return nil, usage
return usage
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
// count tokens by audio file duration
audioTokens, err := countAudioTokens(c)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
@@ -328,9 +327,9 @@ func countAudioTokens(c *gin.Context) (int, error) {
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
}
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
}
info.IsStream = true
@@ -368,7 +367,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
}
realtimeEvent := &dto.RealtimeEvent{}
err = common.UnmarshalJson(message, realtimeEvent)
err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
@@ -428,7 +427,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
}
info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{}
err = common.UnmarshalJson(message, realtimeEvent)
err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
@@ -553,18 +552,18 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
return err
}
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
var usageResp dto.SimpleResponse
err = common.UnmarshalJson(responseBody, &usageResp)
err = common.Unmarshal(responseBody, &usageResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
// 写入新的 response body
@@ -584,5 +583,5 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
}
return nil, &usageResp.Usage
return &usageResp.Usage, nil
}

View File

@@ -9,33 +9,27 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
// read response body
var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
err = common.UnmarshalJson(responseBody, &responsesResponse)
err = common.Unmarshal(responseBody, &responsesResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if responsesResponse.Error != nil {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: responsesResponse.Error.Message,
Type: "openai_error",
Code: responsesResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
}
// 写入新的 response body
@@ -50,13 +44,13 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
for _, tool := range responsesResponse.Tools {
info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++
}
return nil, &usage
return &usage, nil
}
func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
}
var usage = &dto.Usage{}
@@ -99,5 +93,5 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
}
}
return nil, usage
return usage, nil
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -70,13 +71,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
usage, err = palmHandler(c, info, resp)
}
return
}

View File

@@ -2,14 +2,17 @@ package palm
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
@@ -70,7 +73,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
return &response
}
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) {
responseText := ""
responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
@@ -121,42 +124,39 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
return nil, responseText
}
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
Code: palmResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
return nil, types.WithOpenAIError(types.OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
Code: palmResponse.Error.Code,
}, resp.StatusCode)
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
usage := dto.Usage{
PromptTokens: promptTokens,
PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
TotalTokens: info.PromptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return &usage, nil
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -73,11 +74,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -76,20 +77,20 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeRerank:
err, usage = siliconflowRerankHandler(c, resp)
usage, err = siliconflowRerankHandler(c, info, resp)
case constant.RelayModeCompletions:
fallthrough
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}

View File

@@ -2,24 +2,26 @@ package siliconflow
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
@@ -33,10 +35,10 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, usage
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}

View File

@@ -6,10 +6,11 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"strconv"
"strings"
@@ -63,7 +64,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
apiKey := c.Request.Header.Get("Authorization")
apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
a.AppID = appId
@@ -94,13 +95,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, err = tencentStreamHandler(c, info, resp)
} else {
err, usage = tencentHandler(c, resp)
usage, err = tencentHandler(c, info, resp)
}
return
}

View File

@@ -8,17 +8,20 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// https://cloud.tencent.com/document/product/1729/97732
@@ -86,7 +89,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
return &response
}
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
@@ -126,38 +129,35 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
common.CloseResponseBodyGracefully(resp)
return nil, responseText
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
}
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if tencentSb.Response.Error.Code != 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: tencentSb.Response.Error.Message,
Code: tencentSb.Response.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
return nil, types.WithOpenAIError(types.OpenAIError{
Message: tencentSb.Response.Error.Message,
Code: tencentSb.Response.Error.Code,
}, resp.StatusCode)
}
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
jsonResponse, err := json.Marshal(fullTextResponse)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil
}
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {

View File

@@ -14,6 +14,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -208,19 +209,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp)
} else {
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
usage, err = gemini.GeminiChatStreamHandler(c, info, resp)
}
case RequestModeLlama:
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
}
} else {
switch a.RequestMode {
@@ -228,12 +229,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
} else {
err, usage = gemini.GeminiChatHandler(c, resp, info)
usage, err = gemini.GeminiChatHandler(c, info, resp)
}
case RequestModeLlama:
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
}
return

View File

@@ -4,8 +4,11 @@ import "one-api/common"
func GetModelRegion(other string, localModelName string) string {
// if other is json string
if common.IsJsonStr(other) {
m := common.StrToMap(other)
if common.IsJsonObject(other) {
m, err := common.StrToMap(other)
if err != nil {
return other // return original if parsing fails
}
if m[localModelName] != nil {
return m[localModelName].(string)
} else {

View File

@@ -13,6 +13,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"path/filepath"
"strings"
@@ -225,18 +226,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
}
return
}

View File

@@ -8,6 +8,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/types"
"strings"
"one-api/relay/constant"
@@ -104,15 +105,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
default:
if info.IsStream {
err, usage = xAIStreamHandler(c, resp, info)
usage, err = xAIStreamHandler(c, info, resp)
} else {
err, usage = xAIHandler(c, resp, info)
usage, err = xAIHandler(c, info, resp)
}
}
return

View File

@@ -10,6 +10,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -34,7 +35,7 @@ func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage
return openAIResp
}
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
usage := &dto.Usage{}
var responseTextBuilder strings.Builder
var toolCount int
@@ -74,30 +75,28 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
helper.Done(c)
common.CloseResponseBodyGracefully(resp)
return nil, usage
return usage, nil
}
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
var response *dto.SimpleResponse
err = common.UnmarshalJson(responseBody, &response)
err = common.Unmarshal(responseBody, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return nil, nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
// new body
encodeJson, err := common.EncodeJson(response)
encodeJson, err := common.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return nil, nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.IOCopyBytesGracefully(c, resp, encodeJson)
return nil, &response.Usage
return &response.Usage, nil
}

View File

@@ -7,7 +7,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -74,18 +74,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return dummyResp, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
splits := strings.Split(info.ApiKey, "|")
if len(splits) != 3 {
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
return nil, types.NewError(errors.New("invalid auth"), types.ErrorCodeChannelInvalidKey)
}
if a.request == nil {
return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
return nil, types.NewError(errors.New("request is nil"), types.ErrorCodeInvalidRequest)
}
if info.IsStream {
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
usage, err = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
usage, err = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
}
return
}

View File

@@ -6,18 +6,18 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"net/http"
"net/url"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
// https://console.xfyun.cn/services/cbm
@@ -126,11 +126,11 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
return callUrl
}
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
}
helper.SetEventStreamHeaders(c)
var usage dto.Usage
@@ -153,14 +153,14 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
return false
}
})
return nil, &usage
return &usage, nil
}
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
}
var usage dto.Usage
var content string
@@ -191,11 +191,11 @@ func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId s
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
return &usage, nil
}
func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {

View File

@@ -8,6 +8,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -77,11 +78,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = zhipuStreamHandler(c, resp)
usage, err = zhipuStreamHandler(c, info, resp)
} else {
err, usage = zhipuHandler(c, resp)
usage, err = zhipuHandler(c, info, resp)
}
return
}

View File

@@ -3,18 +3,20 @@ package zhipu
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
)
// https://open.bigmodel.cn/doc/api#chatglm_std
@@ -150,7 +152,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
return &response, &zhipuResponse.Usage
}
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage *dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
@@ -211,38 +213,33 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
}
})
common.CloseResponseBodyGracefully(resp)
return nil, usage
return usage, nil
}
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if !zhipuResponse.Success {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",
Code: zhipuResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
return nil, types.WithOpenAIError(types.OpenAIError{
Message: zhipuResponse.Msg,
Code: zhipuResponse.Code,
}, resp.StatusCode)
}
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
return &fullTextResponse.Usage, nil
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -80,11 +81,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}

View File

@@ -2,10 +2,8 @@ package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -14,7 +12,10 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
@@ -32,14 +33,14 @@ func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest
return textRequest, nil
}
func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoClaude(c)
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateClaudeRequest(c)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
if textRequest.Stream {
@@ -48,35 +49,35 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeCountTokenFailed)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeModelPriceError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return service.OpenAIErrorToClaudeError(openaiErr)
if newAPIError != nil {
return newAPIError
}
defer func() {
if openaiErr != nil {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
var requestBody io.Reader
@@ -109,14 +110,14 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
jsonData, err := json.Marshal(convertedRequest)
jsonData, err := common.Marshal(convertedRequest)
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
requestBody = bytes.NewBuffer(jsonData)
@@ -124,26 +125,26 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return service.OpenAIErrorToClaudeError(openaiErr)
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
//log.Printf("usage: %v", usage)
if openaiErr != nil {
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return service.OpenAIErrorToClaudeError(openaiErr)
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil

View File

@@ -213,7 +213,7 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
@@ -229,7 +229,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
@@ -247,7 +247,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
Organization: c.GetString("channel_organization"),
ChannelCreateTime: c.GetInt64("channel_create_time"),

View File

@@ -1,7 +1,6 @@
package common_handler
import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -9,13 +8,15 @@ import (
"one-api/dto"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
@@ -24,9 +25,9 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var jinaResp dto.RerankResponse
if info.ChannelType == constant.ChannelTypeXinference {
var xinRerankResponse xinference.XinRerankResponse
err = common.UnmarshalJson(responseBody, &xinRerankResponse)
err = common.Unmarshal(responseBody, &xinRerankResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
@@ -59,14 +60,14 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
},
}
} else {
err = common.UnmarshalJson(responseBody, &jinaResp)
err = common.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
}
c.Writer.Header().Set("Content-Type", "application/json")
c.JSON(http.StatusOK, jinaResp)
return nil, &jinaResp.Usage
return &jinaResp.Usage, nil
}

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/dto"
@@ -12,6 +11,9 @@ import (
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -32,24 +34,24 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
return nil
}
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
var embeddingRequest *dto.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
promptToken := getEmbeddingPromptToken(*embeddingRequest)
@@ -57,57 +59,57 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeModelPriceError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if openaiErr != nil {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil

View File

@@ -14,6 +14,7 @@ import (
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -104,11 +105,11 @@ func trimModelThinking(modelName string) string {
return modelName
}
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
req, err := getAndValidateGeminiRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
relayInfo := relaycommon.GenRelayInfoGemini(c)
@@ -120,14 +121,14 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
sensitiveWords, err := checkGeminiInputSensitive(req)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
}
}
// model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
if value, exists := c.Get("prompt_tokens"); exists {
@@ -158,23 +159,23 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeModelPriceError)
}
// pre consume quota
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if openaiErr != nil {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
@@ -195,7 +196,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
requestBody, err := json.Marshal(req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
if common.DebugEnabled {
@@ -205,7 +206,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
@@ -215,10 +216,10 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
}

View File

@@ -4,27 +4,29 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/types"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func SetEventStreamHeaders(c *gin.Context) {
// 检查是否已经设置过头部
if _, exists := c.Get("event_stream_headers_set"); exists {
return
}
// 检查是否已经设置过头部
if _, exists := c.Get("event_stream_headers_set"); exists {
return
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
// 设置标志,表示头部已经设置过
c.Set("event_stream_headers_set", true)
// 设置标志,表示头部已经设置过
c.Set("event_stream_headers_set", true)
}
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
@@ -85,7 +87,7 @@ func ObjectData(c *gin.Context, object interface{}) error {
if object == nil {
return errors.New("object is nil")
}
jsonData, err := json.Marshal(object)
jsonData, err := common.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
@@ -118,7 +120,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
return ws.WriteMessage(1, jsonData)
}
func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
errorObj := &dto.RealtimeEvent{
Type: "error",
EventId: GetLocalRealtimeID(c),

View File

@@ -16,6 +16,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -107,23 +108,23 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return imageRequest, nil
}
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoImage(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeModelPriceError)
}
var preConsumedQuota int
var quota int
@@ -132,13 +133,12 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
var openaiErr *dto.OpenAIErrorWithStatusCode
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if openaiErr != nil {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
@@ -169,16 +169,16 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeQueryDataError)
}
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
}
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
@@ -186,14 +186,14 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
requestBody = convertedRequest.(io.Reader)
} else {
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
requestBody = bytes.NewBuffer(jsonData)
}
@@ -206,25 +206,25 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(httpResp, false)
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
if usage.(*dto.Usage).TotalTokens == 0 {

View File

@@ -575,7 +575,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
common.SysError("get_channel_null: " + err.Error())
}
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
}
}
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {

View File

@@ -19,6 +19,7 @@ import (
"one-api/setting"
"one-api/setting/model_setting"
"one-api/setting/operation_setting"
"one-api/types"
"strings"
"time"
@@ -84,7 +85,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
return textRequest, nil
}
func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfo(c)
@@ -92,8 +93,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
textRequest, err := getAndValidateTextRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
if textRequest.WebSearchOptions != nil {
@@ -104,13 +104,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
}
}
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
// 获取 promptTokens如果上下文中已经存在则直接使用
@@ -122,23 +122,23 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
promptTokens, err = getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeCountTokenFailed)
}
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeModelPriceError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newApiErr != nil {
return newApiErr
}
defer func() {
if openaiErr != nil {
if newApiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
@@ -166,7 +166,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
var requestBody io.Reader
@@ -174,32 +174,29 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
err = json.Unmarshal(jsonData, &reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
}
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = json.Marshal(reqMap)
jsonData, err = common.Marshal(reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
}
}
@@ -213,7 +210,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
@@ -222,18 +219,18 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
newApiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
return newApiErr
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if newApiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
return newApiErr
}
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
@@ -281,16 +278,16 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
}
// 预扣费并返回用户剩余配额
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
}
if userQuota <= 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
}
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
}
relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota {
@@ -314,11 +311,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if preConsumedQuota > 0 {
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
}
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
}
}
return preConsumedQuota, userQuota, nil

View File

@@ -2,15 +2,16 @@ package relay
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -22,27 +23,27 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
return token
}
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) {
var rerankRequest *dto.RerankRequest
err := common.UnmarshalBodyReusable(c, &rerankRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
if rerankRequest.Query == "" {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
}
if len(rerankRequest.Documents) == 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
}
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
promptToken := getRerankPromptToken(*rerankRequest)
@@ -50,32 +51,32 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeModelPriceError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if openaiErr != nil {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
jsonData, err := json.Marshal(convertedRequest)
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
requestBody := bytes.NewBuffer(jsonData)
if common.DebugEnabled {
@@ -83,7 +84,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
}
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
@@ -91,18 +92,18 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil

View File

@@ -14,6 +14,7 @@ import (
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -46,11 +47,11 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo
return inputTokens
}
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
req, err := getAndValidateResponsesRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
@@ -59,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
sensitiveWords, err := checkInputSensitive(req, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
}
}
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
if value, exists := c.Get("prompt_tokens"); exists {
@@ -78,52 +79,52 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeModelPriceError)
}
// pre consume quota
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if openaiErr != nil {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_request_body_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_error", http.StatusBadRequest)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "marshal_request_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
err = json.Unmarshal(jsonData, &reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
}
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = json.Marshal(reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
}
@@ -136,7 +137,7 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
@@ -145,18 +146,18 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {

View File

@@ -1,18 +1,18 @@
package relay
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
// get & validate textRequest 获取并验证文本请求
@@ -22,42 +22,31 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
//}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[relayInfo.OriginModelName] != "" {
relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
// set upstream model name
//isModelMapped = true
}
err := helper.ModelMappedHelper(c, relayInfo, nil)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeModelPriceError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if openaiErr != nil {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
//var requestBody io.Reader
@@ -67,7 +56,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, nil)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
if resp != nil {
@@ -75,11 +64,11 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
defer relayInfo.TargetWs.Close()
}
usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
if openaiErr != nil {
usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
userQuota, priceData, "")

View File

@@ -20,7 +20,7 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("/:model", controller.RetrieveModel)
}
playgroundRouter := router.Group("/pg")
playgroundRouter.Use(middleware.UserAuth())
playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
{
playgroundRouter.POST("/chat/completions", controller.Playground)
}

View File

@@ -8,6 +8,7 @@ import (
"one-api/dto"
"one-api/model"
"one-api/setting/operation_setting"
"one-api/types"
"strings"
)
@@ -16,17 +17,17 @@ func formatNotifyType(channelId int, status int) string {
}
// disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
func DisableChannel(channelError types.ChannelError, reason string) {
success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason)
if success {
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelError.ChannelName, channelError.ChannelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelError.ChannelName, channelError.ChannelId, reason)
NotifyRootUser(formatNotifyType(channelError.ChannelId, common.ChannelStatusAutoDisabled), subject, content)
}
}
func EnableChannel(channelId int, channelName string) {
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
func EnableChannel(channelId int, usingKey string, channelName string) {
success := model.UpdateChannelStatus(channelId, usingKey, common.ChannelStatusEnabled, "")
if success {
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
@@ -34,14 +35,17 @@ func EnableChannel(channelId int, channelName string) {
}
}
func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if err.LocalError {
if types.IsChannelError(err) {
return true
}
if types.IsLocalError(err) {
return false
}
if err.StatusCode == http.StatusUnauthorized {
@@ -53,7 +57,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
return true
}
}
switch err.Error.Code {
oaiErr := err.ToOpenAIError()
switch oaiErr.Code {
case "invalid_api_key":
return true
case "account_deactivated":
@@ -63,7 +68,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
case "pre_consume_token_quota_failed":
return true
}
switch err.Error.Type {
switch oaiErr.Type {
case "insufficient_quota":
return true
case "insufficient_user_quota":
@@ -77,23 +82,16 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
return true
}
lowerMessage := strings.ToLower(err.Error.Message)
lowerMessage := strings.ToLower(err.Error())
search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true)
if search {
return true
}
return false
return search
}
func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
func ShouldEnableChannel(newAPIError *types.NewAPIError, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openaiWithStatusErr != nil {
if newAPIError != nil {
return false
}
if status != common.ChannelStatusAutoDisabled {

View File

@@ -163,7 +163,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
} else {
mediaContents := mediaMsg.ParseMediaContent()
encodeJson, _ := common.EncodeJson(mediaContents)
encodeJson, _ := common.Marshal(mediaContents)
oaiToolMessage.SetStringContent(string(encodeJson))
}
openAIMessages = append(openAIMessages, oaiToolMessage)

View File

@@ -2,11 +2,13 @@ package service
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/types"
"strconv"
"strings"
)
@@ -25,32 +27,32 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
}
}
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
}
openAIError := dto.OpenAIError{
Message: text,
Type: "new_api_error",
Code: code,
}
return &dto.OpenAIErrorWithStatusCode{
Error: openAIError,
StatusCode: statusCode,
}
}
func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
openaiErr := OpenAIErrorWrapper(err, code, statusCode)
openaiErr.LocalError = true
return openaiErr
}
//// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
//func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
// text := err.Error()
// lowerText := strings.ToLower(text)
// if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
// if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
// common.SysLog(fmt.Sprintf("error: %s", text))
// text = "请求上游地址失败"
// }
// }
// openAIError := dto.OpenAIError{
// Message: text,
// Type: "new_api_error",
// Code: code,
// }
// return &dto.OpenAIErrorWithStatusCode{
// Error: openAIError,
// StatusCode: statusCode,
// }
//}
//
//func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
// openaiErr := OpenAIErrorWrapper(err, code, statusCode)
// openaiErr.LocalError = true
// return openaiErr
//}
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
text := err.Error()
@@ -77,43 +79,37 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
return claudeErr
}
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
newApiErr = &types.NewAPIError{
StatusCode: resp.StatusCode,
Error: dto.OpenAIError{
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
common.CloseResponseBodyGracefully(resp)
var errResponse dto.GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
err = common.Unmarshal(responseBody, &errResponse)
if err != nil {
if showBodyWhenFail {
errWithStatusCode.Error.Message = string(responseBody)
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
} else {
errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
}
return
}
if errResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
errWithStatusCode.Error = errResponse.Error
// General format error (OpenAI, Anthropic, Gemini, etc.)
newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode)
} else {
errWithStatusCode.Error.Message = errResponse.ToMessage()
}
if errWithStatusCode.Error.Message == "" {
errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
}
return
}
func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) {
func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) {
if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
return
}
@@ -122,13 +118,13 @@ func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMapping
if err != nil {
return
}
if openaiErr.StatusCode == http.StatusOK {
if newApiErr.StatusCode == http.StatusOK {
return
}
codeStr := strconv.Itoa(openaiErr.StatusCode)
codeStr := strconv.Itoa(newApiErr.StatusCode)
if _, ok := statusCodeMapping[codeStr]; ok {
intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
openaiErr.StatusCode = intCode
newApiErr.StatusCode = intCode
}
}

View File

@@ -1,6 +1,8 @@
package service
import (
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -28,6 +30,11 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
}
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
isMultiKey := common.GetContextKeyBool(ctx, constant.ContextKeyChannelIsMultiKey)
if isMultiKey {
adminInfo["is_multi_key"] = true
adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
}
other["admin_info"] = adminInfo
return other
}

View File

@@ -204,7 +204,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
auth := c.Request.Header.Get("Authorization")
auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
if auth != "" {
auth = strings.TrimPrefix(auth, "Bearer ")
req.Header.Set("mj-api-secret", auth)

21
types/channel_error.go Normal file
View File

@@ -0,0 +1,21 @@
package types
type ChannelError struct {
ChannelId int `json:"channel_id"`
ChannelType int `json:"channel_type"`
ChannelName string `json:"channel_name"`
IsMultiKey bool `json:"is_multi_key"`
AutoBan bool `json:"auto_ban"`
UsingKey string `json:"using_key"`
}
func NewChannelError(channelId int, channelType int, channelName string, isMultiKey bool, usingKey string, autoBan bool) *ChannelError {
return &ChannelError{
ChannelId: channelId,
ChannelType: channelType,
ChannelName: channelName,
IsMultiKey: isMultiKey,
AutoBan: autoBan,
UsingKey: usingKey,
}
}

195
types/error.go Normal file
View File

@@ -0,0 +1,195 @@
package types
import (
"errors"
"fmt"
"net/http"
"strings"
)
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type ClaudeError struct {
Message string `json:"message,omitempty"`
Type string `json:"type,omitempty"`
}
type ErrorType string
const (
ErrorTypeNewAPIError ErrorType = "new_api_error"
ErrorTypeOpenAIError ErrorType = "openai_error"
ErrorTypeClaudeError ErrorType = "claude_error"
ErrorTypeMidjourneyError ErrorType = "midjourney_error"
ErrorTypeGeminiError ErrorType = "gemini_error"
ErrorTypeRerankError ErrorType = "rerank_error"
)
type ErrorCode string
const (
ErrorCodeInvalidRequest ErrorCode = "invalid_request"
ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected"
// new api error
ErrorCodeCountTokenFailed ErrorCode = "count_token_failed"
ErrorCodeModelPriceError ErrorCode = "model_price_error"
ErrorCodeInvalidApiType ErrorCode = "invalid_api_type"
ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed"
ErrorCodeDoRequestFailed ErrorCode = "do_request_failed"
ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed"
// channel error
ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key"
ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid"
ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error"
ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"
ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key"
ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"
// client request error
ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"
ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed"
ErrorCodeAccessDenied ErrorCode = "access_denied"
// response error
ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
ErrorCodeBadResponse ErrorCode = "bad_response"
ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
// sql error
ErrorCodeQueryDataError ErrorCode = "query_data_error"
ErrorCodeUpdateDataError ErrorCode = "update_data_error"
// quota error
ErrorCodeInsufficientUserQuota ErrorCode = "insufficient_user_quota"
ErrorCodePreConsumeTokenQuotaFailed ErrorCode = "pre_consume_token_quota_failed"
)
type NewAPIError struct {
Err error
RelayError any
ErrorType ErrorType
errorCode ErrorCode
StatusCode int
}
func (e *NewAPIError) GetErrorCode() ErrorCode {
if e == nil {
return ""
}
return e.errorCode
}
func (e *NewAPIError) Error() string {
return e.Err.Error()
}
func (e *NewAPIError) SetMessage(message string) {
e.Err = errors.New(message)
}
func (e *NewAPIError) ToOpenAIError() OpenAIError {
switch e.ErrorType {
case ErrorTypeOpenAIError:
return e.RelayError.(OpenAIError)
case ErrorTypeClaudeError:
claudeError := e.RelayError.(ClaudeError)
return OpenAIError{
Message: e.Error(),
Type: claudeError.Type,
Param: "",
Code: e.errorCode,
}
default:
return OpenAIError{
Message: e.Error(),
Type: string(e.ErrorType),
Param: "",
Code: e.errorCode,
}
}
}
func (e *NewAPIError) ToClaudeError() ClaudeError {
switch e.ErrorType {
case ErrorTypeOpenAIError:
openAIError := e.RelayError.(OpenAIError)
return ClaudeError{
Message: e.Error(),
Type: fmt.Sprintf("%v", openAIError.Code),
}
case ErrorTypeClaudeError:
return e.RelayError.(ClaudeError)
default:
return ClaudeError{
Message: e.Error(),
Type: string(e.ErrorType),
}
}
}
func NewError(err error, errorCode ErrorCode) *NewAPIError {
return &NewAPIError{
Err: err,
RelayError: nil,
ErrorType: ErrorTypeNewAPIError,
StatusCode: http.StatusInternalServerError,
errorCode: errorCode,
}
}
func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
return &NewAPIError{
Err: err,
RelayError: nil,
ErrorType: ErrorTypeNewAPIError,
StatusCode: statusCode,
errorCode: errorCode,
}
}
func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
code, ok := openAIError.Code.(string)
if !ok {
code = fmt.Sprintf("%v", openAIError.Code)
}
return &NewAPIError{
RelayError: openAIError,
ErrorType: ErrorTypeOpenAIError,
StatusCode: statusCode,
Err: errors.New(openAIError.Message),
errorCode: ErrorCode(code),
}
}
func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError {
return &NewAPIError{
RelayError: claudeError,
ErrorType: ErrorTypeClaudeError,
StatusCode: statusCode,
Err: errors.New(claudeError.Message),
errorCode: ErrorCode(claudeError.Type),
}
}
func IsChannelError(err *NewAPIError) bool {
if err == nil {
return false
}
return strings.HasPrefix(string(err.errorCode), "channel:")
}
func IsLocalError(err *NewAPIError) bool {
if err == nil {
return false
}
return err.ErrorType == ErrorTypeNewAPIError
}

View File

@@ -42,18 +42,20 @@ import {
IconTreeTriangleDown,
IconSearch,
IconMore,
IconList, IconDescend2
} from '@douyinfe/semi-icons';
import { loadChannelModels, isMobile, copy } from '../../helpers';
import EditTagModal from '../../pages/Channel/EditTagModal.js';
import { useTranslation } from 'react-i18next';
import { useTableCompactMode } from '../../hooks/useTableCompactMode';
import { FaRandom } from 'react-icons/fa';
const ChannelsTable = () => {
const { t } = useTranslation();
let type2label = undefined;
const renderType = (type) => {
const renderType = (type, channelInfo = undefined) => {
if (!type2label) {
type2label = new Map();
for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
@@ -61,11 +63,30 @@ const ChannelsTable = () => {
}
type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' };
}
let icon = getChannelIcon(type);
if (channelInfo?.is_multi_key) {
icon = (
channelInfo?.multi_key_mode === 'random' ? (
<div className="flex items-center gap-1">
<FaRandom className="text-blue-500" />
{icon}
</div>
) : (
<div className="flex items-center gap-1">
<IconDescend2 className="text-blue-500" />
{icon}
</div>
)
)
}
return (
<Tag
color={type2label[type]?.color}
shape='circle'
prefixIcon={getChannelIcon(type)}
prefixIcon={icon}
>
{type2label[type]?.label}
</Tag>
@@ -84,7 +105,19 @@ const ChannelsTable = () => {
);
};
const renderStatus = (status) => {
const renderStatus = (status, channelInfo = undefined) => {
if (channelInfo) {
if (channelInfo.is_multi_key) {
let keySize = channelInfo.multi_key_size;
let enabledKeySize = keySize;
if (channelInfo.multi_key_status_list) {
// multi_key_status_list is a map, key is key, value is status
// get multi_key_status_list length
enabledKeySize = keySize - Object.keys(channelInfo.multi_key_status_list).length;
}
return renderMultiKeyStatus(status, keySize, enabledKeySize);
}
}
switch (status) {
case 1:
return (
@@ -113,6 +146,36 @@ const ChannelsTable = () => {
}
};
const renderMultiKeyStatus = (status, keySize, enabledKeySize) => {
switch (status) {
case 1:
return (
<Tag color='green' shape='circle'>
{t('已启用')} {enabledKeySize}/{keySize}
</Tag>
);
case 2:
return (
<Tag color='red' shape='circle'>
{t('已禁用')} {enabledKeySize}/{keySize}
</Tag>
);
case 3:
return (
<Tag color='yellow' shape='circle'>
{t('自动禁用')} {enabledKeySize}/{keySize}
</Tag>
);
default:
return (
<Tag color='grey' shape='circle'>
{t('未知状态')} {enabledKeySize}/{keySize}
</Tag>
);
}
}
const renderResponseTime = (responseTime) => {
let time = responseTime / 1000;
time = time.toFixed(2) + t(' 秒');
@@ -279,6 +342,11 @@ const ChannelsTable = () => {
dataIndex: 'type',
render: (text, record, index) => {
if (record.children === undefined) {
if (record.channel_info) {
if (record.channel_info.is_multi_key) {
return <>{renderType(text, record.channel_info)}</>;
}
}
return <>{renderType(text)}</>;
} else {
return <>{renderTagType()}</>;
@@ -302,12 +370,12 @@ const ChannelsTable = () => {
<Tooltip
content={t('原因:') + reason + t(',时间:') + timestamp2string(time)}
>
{renderStatus(text)}
{renderStatus(text, record.channel_info)}
</Tooltip>
</div>
);
} else {
return renderStatus(text);
return renderStatus(text, record.channel_info);
}
},
},
@@ -524,24 +592,70 @@ const ChannelsTable = () => {
/>
</SplitButtonGroup>
{record.status === 1 ? (
<Button
theme='light'
type='warning'
size="small"
onClick={() => manageChannel(record.id, 'disable', record)}
{record.channel_info?.is_multi_key ? (
<SplitButtonGroup
aria-label={t('多密钥渠道操作项目组')}
>
{t('禁用')}
</Button>
{
record.status === 1 ? (
<Button
theme='light'
type='warning'
size="small"
onClick={() => manageChannel(record.id, 'disable', record)}
>
{t('禁用')}
</Button>
) : (
<Button
theme='light'
type='secondary'
size="small"
onClick={() => manageChannel(record.id, 'enable', record)}
>
{t('启用')}
</Button>
)
}
<Dropdown
trigger='click'
position='bottomRight'
menu={[
{
node: 'item',
name: t('启用全部密钥'),
onClick: () => manageChannel(record.id, 'enable_all', record),
}
]}
>
<Button
theme='light'
type='secondary'
size="small"
icon={<IconTreeTriangleDown />}
/>
</Dropdown>
</SplitButtonGroup>
) : (
<Button
theme='light'
type='secondary'
size="small"
onClick={() => manageChannel(record.id, 'enable', record)}
>
{t('启用')}
</Button>
record.status === 1 ? (
<Button
theme='light'
type='warning'
size="small"
onClick={() => manageChannel(record.id, 'disable', record)}
>
{t('禁用')}
</Button>
) : (
<Button
theme='light'
type='secondary'
size="small"
onClick={() => manageChannel(record.id, 'enable', record)}
>
{t('启用')}
</Button>
)
)}
<Button
@@ -951,6 +1065,11 @@ const ChannelsTable = () => {
}
res = await API.put('/api/channel/', data);
break;
case 'enable_all':
data.channel_info = record.channel_info;
data.channel_info.multi_key_status_list = {};
res = await API.put('/api/channel/', data);
break;
}
const { success, message } = res.data;
if (success) {
@@ -1882,7 +2001,6 @@ const ChannelsTable = () => {
placeholder={t('请输入标签名称')}
value={batchSetTagValue}
onChange={(v) => setBatchSetTagValue(v)}
size='large'
/>
<div className="mt-4">
<Typography.Text type='secondary'>

View File

@@ -20,7 +20,7 @@ import {
renderQuota,
stringToColor,
getLogOther,
renderModelTag,
renderModelTag
} from '../../helpers';
import {
@@ -356,21 +356,43 @@ const LogsTable = () => {
dataIndex: 'channel',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
let isMultiKey = false
let multiKeyIndex = -1;
let other = getLogOther(record.other);
if (other?.admin_info) {
let adminInfo = other.admin_info;
if (adminInfo?.is_multi_key) {
isMultiKey = true;
multiKeyIndex = adminInfo.multi_key_index;
}
}
return isAdminUser ? (
record.type === 0 || record.type === 2 || record.type === 5 ? (
<div>
<>
{
<Tooltip content={record.channel_name || '[未知]'}>
<Tag
color={colors[parseInt(text) % colors.length]}
shape='circle'
>
{' '}
{text}{' '}
</Tag>
<Space>
<Tag
color={colors[parseInt(text) % colors.length]}
shape='circle'
>
{text}
</Tag>
{
isMultiKey && (
<Tag
color={'white'}
shape='circle'
>
{multiKeyIndex}
</Tag>
)
}
</Space>
</Tooltip>
}
</div>
</>
) : (
<></>
)

View File

@@ -87,6 +87,7 @@ export const buildApiPayload = (messages, systemPrompt, inputs, parameterEnabled
const payload = {
model: inputs.model,
group: inputs.group,
messages: processedMessages,
group: inputs.group,
stream: inputs.stream,

View File

@@ -1763,5 +1763,14 @@
"生成数量必须大于0": "Generation quantity must be greater than 0",
"可用端点类型": "Supported endpoint types",
"未登录,使用默认分组倍率:": "Not logged in, using default group ratio: ",
"该服务器地址将影响支付回调地址以及默认首页展示的地址,请确保正确配置": "This server address will affect the payment callback address and the address displayed on the default homepage, please ensure correct configuration"
"该服务器地址将影响支付回调地址以及默认首页展示的地址,请确保正确配置": "This server address will affect the payment callback address and the address displayed on the default homepage, please ensure correct configuration",
"密钥聚合模式": "Key aggregation mode",
"随机": "Random",
"轮询": "Polling",
"密钥文件 (.json)": "Key file (.json)",
"点击上传文件或拖拽文件到这里": "Click to upload file or drag and drop file here",
"仅支持 JSON 文件,支持多文件": "Only JSON files are supported, multiple files are supported",
"请上传密钥文件": "Please upload the key file",
"请填写部署地区": "Please fill in the deployment region",
"请输入部署地区例如us-central1\n支持使用模型映射格式\n{\n \"default\": \"us-central1\",\n \"claude-3-5-sonnet-20240620\": \"europe-west1\"\n}": "Please enter the deployment region, for example: us-central1\nSupports using model mapping format\n{\n \"default\": \"us-central1\",\n \"claude-3-5-sonnet-20240620\": \"europe-west1\"\n}"
}

View File

@@ -26,6 +26,7 @@ import {
Form,
Row,
Col,
Upload,
} from '@douyinfe/semi-ui';
import { getChannelModels, copy, getChannelIcon } from '../../helpers';
import {
@@ -35,6 +36,7 @@ import {
IconSetting,
IconCode,
IconGlobe,
IconBolt,
} from '@douyinfe/semi-icons';
const { Text, Title } = Typography;
@@ -100,10 +102,12 @@ const EditChannel = (props) => {
priority: 0,
weight: 0,
tag: '',
multi_key_mode: 'random',
};
const [batch, setBatch] = useState(false);
const [multiToSingle, setMultiToSingle] = useState(false);
const [multiKeyMode, setMultiKeyMode] = useState('random');
const [autoBan, setAutoBan] = useState(true);
// const [autoBan, setAutoBan] = useState(true);
const [inputs, setInputs] = useState(originInputs);
const [originModelOptions, setOriginModelOptions] = useState([]);
const [modelOptions, setModelOptions] = useState([]);
@@ -114,6 +118,10 @@ const EditChannel = (props) => {
const [modalImageUrl, setModalImageUrl] = useState('');
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
const formApiRef = useRef(null);
const [vertexKeys, setVertexKeys] = useState([]);
const [vertexFileList, setVertexFileList] = useState([]);
const vertexErroredNames = useRef(new Set()); // 避免重复报错
const [isMultiKeyChannel, setIsMultiKeyChannel] = useState(false);
const getInitValues = () => ({ ...originInputs });
const handleInputChange = (name, value) => {
if (formApiRef.current) {
@@ -211,6 +219,19 @@ const EditChannel = (props) => {
2,
);
}
const chInfo = data.channel_info || {};
const isMulti = chInfo.is_multi_key === true;
setIsMultiKeyChannel(isMulti);
if (isMulti) {
setBatch(true);
setMultiToSingle(true);
const modeVal = chInfo.multi_key_mode || 'random';
setMultiKeyMode(modeVal);
data.multi_key_mode = modeVal;
} else {
setBatch(false);
setMultiToSingle(false);
}
setInputs(data);
if (formApiRef.current) {
formApiRef.current.setValues(data);
@@ -381,10 +402,76 @@ const EditChannel = (props) => {
}
}, [props.visible, channelId]);
const handleVertexUploadChange = ({ fileList }) => {
(async () => {
const validFiles = [];
const keys = [];
const errorNames = [];
for (const item of fileList) {
const fileObj = item.fileInstance;
if (!fileObj) continue;
try {
const txt = await fileObj.text();
keys.push(JSON.parse(txt));
validFiles.push(item); // 仅合法文件加入列表
} catch (err) {
if (!vertexErroredNames.current.has(item.name)) {
errorNames.push(item.name);
vertexErroredNames.current.add(item.name);
}
}
}
setVertexKeys(keys);
setVertexFileList(validFiles);
if (formApiRef.current) {
formApiRef.current.setValue('vertex_files', validFiles);
}
setInputs((prev) => ({ ...prev, vertex_files: validFiles }));
if (errorNames.length > 0) {
showError(t('以下文件解析失败,已忽略:{{list}}', { list: errorNames.join(', ') }));
}
})();
};
const submit = async () => {
const formValues = formApiRef.current ? formApiRef.current.getValues() : {};
let localInputs = { ...formValues };
if (localInputs.type === 41) {
let keys = vertexKeys;
if (keys.length === 0) {
// 确保提交时也能解析,避免因异步延迟导致 keys 为空
try {
const parsed = await Promise.all(
vertexFileList.map(async (item) => {
const fileObj = item.fileInstance;
if (!fileObj) return null;
const txt = await fileObj.text();
return JSON.parse(txt);
})
);
keys = parsed.filter(Boolean);
} catch (err) {
showError(t('解析密钥文件失败: {{msg}}', { msg: err.message }));
return;
}
}
if (keys.length === 0) {
showInfo(t('请上传密钥文件!'));
return;
}
if (batch) {
localInputs.key = JSON.stringify(keys);
} else {
localInputs.key = JSON.stringify(keys[0]);
}
}
delete localInputs.vertex_files;
if (!isEdit && (!localInputs.name || !localInputs.key)) {
showInfo(t('请填写渠道名称和渠道密钥!'));
return;
@@ -410,13 +497,23 @@ const EditChannel = (props) => {
localInputs.auto_ban = localInputs.auto_ban ? 1 : 0;
localInputs.models = localInputs.models.join(',');
localInputs.group = (localInputs.groups || []).join(',');
let mode = 'single';
if (batch) {
mode = multiToSingle ? 'multi_to_single' : 'batch';
}
if (isEdit) {
res = await API.put(`/api/channel/`, {
...localInputs,
id: parseInt(channelId),
});
} else {
res = await API.post(`/api/channel/`, localInputs);
res = await API.post(`/api/channel/`, {
mode: mode,
multi_key_mode: mode === 'multi_to_single' ? multiKeyMode : undefined,
channel: localInputs,
});
}
const { success, message } = res.data;
if (success) {
@@ -469,9 +566,31 @@ const EditChannel = (props) => {
}
};
const batchAllowed = !isEdit && inputs.type !== 41;
const batchAllowed = !isEdit || isMultiKeyChannel;
const batchExtra = batchAllowed ? (
<Checkbox checked={batch} onChange={() => setBatch(!batch)}>{t('批量创建')}</Checkbox>
<Space>
<Checkbox disabled={isEdit} checked={batch} onChange={() => {
setBatch(!batch);
if (batch) {
setMultiToSingle(false);
setMultiKeyMode('random');
}
}}>{t('批量创建')}</Checkbox>
{batch && (
<Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => {
setMultiToSingle(prev => !prev);
setInputs(prev => {
const newInputs = { ...prev };
if (!multiToSingle) {
newInputs.multi_key_mode = multiKeyMode;
} else {
delete newInputs.multi_key_mode;
}
return newInputs;
});
}}>{t('密钥聚合模式')}</Checkbox>
)}
</Space>
) : null;
const channelOptionList = useMemo(
@@ -571,52 +690,93 @@ const EditChannel = (props) => {
/>
{batch ? (
<Form.TextArea
field='key'
label={t('密钥')}
placeholder={t('请输入密钥,一行一个')}
rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
autosize={{ minRows: 6, maxRows: 6 }}
autoComplete='new-password'
onChange={(value) => handleInputChange('key', value)}
extraText={batchExtra}
/>
inputs.type === 41 ? (
<Form.Upload
field='vertex_files'
label={t('密钥文件 (.json)')}
accept='.json'
multiple
draggable
dragIcon={<IconBolt />}
dragMainText={t('点击上传文件或拖拽文件到这里')}
dragSubText={t('仅支持 JSON 文件,支持多文件')}
style={{ marginTop: 10 }}
uploadTrigger='custom'
beforeUpload={() => false}
onChange={handleVertexUploadChange}
fileList={vertexFileList}
rules={isEdit ? [] : [{ required: true, message: t('请上传密钥文件') }]}
extraText={batchExtra}
/>
) : (
<Form.TextArea
field='key'
label={t('密钥')}
placeholder={t('请输入密钥,一行一个')}
rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
autosize
autoComplete='new-password'
onChange={(value) => handleInputChange('key', value)}
extraText={batchExtra}
showClear
/>
)
) : (
<>
{inputs.type === 41 ? (
<Form.TextArea
field='key'
label={t('密钥')}
placeholder={
'{\n' +
' "type": "service_account",\n' +
' "project_id": "abc-bcd-123-456",\n' +
' "private_key_id": "123xxxxx456",\n' +
' "private_key": "-----BEGIN PRIVATE KEY-----xxxx\n' +
' "client_email": "xxx@developer.gserviceaccount.com",\n' +
' "client_id": "111222333",\n' +
' "auth_uri": "https://accounts.google.com/o/oauth2/auth",\n' +
' "token_uri": "https://oauth2.googleapis.com/token",\n' +
' "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",\n' +
' "client_x509_cert_url": "https://xxxxx.gserviceaccount.com",\n' +
' "universe_domain": "googleapis.com"\n' +
'}'
}
rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
autosize={{ minRows: 10 }}
autoComplete='new-password'
onChange={(value) => handleInputChange('key', value)}
<Form.Upload
field='vertex_files'
label={t('密钥文件 (.json)')}
accept='.json'
draggable
dragIcon={<IconBolt />}
dragMainText={t('点击上传文件或拖拽文件到这里')}
dragSubText={t('仅支持 JSON 文件')}
style={{ marginTop: 10 }}
uploadTrigger='custom'
beforeUpload={() => false}
onChange={handleVertexUploadChange}
fileList={vertexFileList}
rules={isEdit ? [] : [{ required: true, message: t('请上传密钥文件') }]}
extraText={batchExtra}
/>
) : (
<Form.Input
field='key'
label={t('密钥')}
label={isEdit ? t('密钥(编辑模式下,保存的密钥不会显示)') : t('密钥')}
placeholder={t(type2secretPrompt(inputs.type))}
rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
autoComplete='new-password'
onChange={(value) => handleInputChange('key', value)}
extraText={batchExtra}
showClear
/>
)}
</>
)}
{batch && multiToSingle && (
<>
<Form.Select
field='multi_key_mode'
label={t('密钥聚合模式')}
placeholder={t('请选择多密钥使用策略')}
optionList={[
{ label: t('随机'), value: 'random' },
{ label: t('轮询'), value: 'polling' },
]}
style={{ width: '100%' }}
value={inputs.multi_key_mode || 'random'}
onChange={(value) => {
setMultiKeyMode(value);
handleInputChange('multi_key_mode', value);
}}
/>
{inputs.multi_key_mode === 'polling' && (
<Banner
type='warning'
description={t('轮询模式必须搭配Redis和内存缓存功能使用否则性能将大幅降低并且无法实现轮询功能')}
className='!rounded-lg mt-2'
/>
)}
</>
@@ -639,8 +799,9 @@ const EditChannel = (props) => {
placeholder={t(
'请输入部署地区例如us-central1\n支持使用模型映射格式\n{\n "default": "us-central1",\n "claude-3-5-sonnet-20240620": "europe-west1"\n}'
)}
autosize={{ minRows: 2 }}
autosize
onChange={(value) => handleInputChange('other', value)}
rules={[{ required: true, message: t('请填写部署地区') }]}
extraText={
<Text
className="!text-semi-color-primary cursor-pointer"
@@ -649,6 +810,7 @@ const EditChannel = (props) => {
{t('填入模板')}
</Text>
}
showClear
/>
)}