Merge branch 'alpha' into feature/simple_stripe
This commit is contained in:
@@ -2,6 +2,7 @@ FROM oven/bun:latest AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
COPY web/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web .
|
||||
COPY ./VERSION .
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -32,7 +33,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = UnmarshalJson(requestBody, &v)
|
||||
err = Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
@@ -86,3 +87,25 @@ func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool)
|
||||
var t T
|
||||
return t, false
|
||||
}
|
||||
|
||||
func ApiError(c *gin.Context, err error) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ApiErrorMsg(c *gin.Context, msg string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": msg,
|
||||
})
|
||||
}
|
||||
|
||||
func ApiSuccess(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func UnmarshalJson(data []byte, v any) error {
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,6 @@ func DecodeJson(reader *bytes.Reader, v any) error {
|
||||
return json.NewDecoder(reader).Decode(v)
|
||||
}
|
||||
|
||||
func EncodeJson(v any) ([]byte, error) {
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PageInfo struct {
|
||||
Page int `json:"page"` // page num 页码
|
||||
PageSize int `json:"page_size"` // page size 页大小
|
||||
StartTimestamp int64 `json:"start_timestamp"` // 秒级
|
||||
EndTimestamp int64 `json:"end_timestamp"` // 秒级
|
||||
Page int `json:"page"` // page num 页码
|
||||
PageSize int `json:"page_size"` // page size 页大小
|
||||
|
||||
Total int `json:"total"` // 总条数,后设置
|
||||
Items any `json:"items"` // 数据,后设置
|
||||
@@ -39,11 +38,14 @@ func (p *PageInfo) SetItems(items any) {
|
||||
p.Items = items
|
||||
}
|
||||
|
||||
func GetPageQuery(c *gin.Context) (*PageInfo, error) {
|
||||
func GetPageQuery(c *gin.Context) *PageInfo {
|
||||
pageInfo := &PageInfo{}
|
||||
err := c.BindQuery(pageInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 手动获取并处理每个参数
|
||||
if page, err := strconv.Atoi(c.Query("page")); err == nil {
|
||||
pageInfo.Page = page
|
||||
}
|
||||
if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
if pageInfo.Page < 1 {
|
||||
// 兼容
|
||||
@@ -56,7 +58,25 @@ func GetPageQuery(c *gin.Context) (*PageInfo, error) {
|
||||
}
|
||||
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageInfo.PageSize = ItemsPerPage
|
||||
// 兼容
|
||||
pageSize, _ := strconv.Atoi(c.Query("ps"))
|
||||
if pageSize != 0 {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageSize, _ = strconv.Atoi(c.Query("size")) // token page
|
||||
if pageSize != 0 {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
}
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageInfo.PageSize = ItemsPerPage
|
||||
}
|
||||
}
|
||||
return pageInfo, nil
|
||||
|
||||
if pageInfo.PageSize > 100 {
|
||||
pageInfo.PageSize = 100
|
||||
}
|
||||
|
||||
return pageInfo
|
||||
}
|
||||
|
||||
@@ -32,16 +32,30 @@ func MapToJsonStr(m map[string]interface{}) string {
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func StrToMap(str string) map[string]interface{} {
|
||||
func StrToMap(str string) (map[string]interface{}, error) {
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(str), &m)
|
||||
err := Unmarshal([]byte(str), &m)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
return m
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func IsJsonStr(str string) bool {
|
||||
func StrToJsonArray(str string) ([]interface{}, error) {
|
||||
var js []interface{}
|
||||
err := json.Unmarshal([]byte(str), &js)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func IsJsonArray(str string) bool {
|
||||
var js []interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
|
||||
func IsJsonObject(str string) bool {
|
||||
var js map[string]interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
|
||||
@@ -17,11 +17,20 @@ const (
|
||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||
|
||||
/* channel related keys */
|
||||
ContextKeyBaseUrl ContextKey = "base_url"
|
||||
ContextKeyChannelType ContextKey = "channel_type"
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyParamOverride ContextKey = "param_override"
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelName ContextKey = "channel_name"
|
||||
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
|
||||
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||
ContextKeyChannelType ContextKey = "channel_type"
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
|
||||
ContextKeyChannelKey ContextKey = "channel_key"
|
||||
|
||||
/* user related keys */
|
||||
ContextKeyUserId ContextKey = "id"
|
||||
|
||||
8
constant/multi_key_mode.go
Normal file
8
constant/multi_key_mode.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package constant
|
||||
|
||||
type MultiKeyMode string
|
||||
|
||||
const (
|
||||
MultiKeyModeRandom MultiKeyMode = "random" // 随机
|
||||
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
|
||||
)
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/shopspring/decimal"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -12,9 +11,12 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -409,26 +411,24 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
func UpdateChannelBalance(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
channel, err := model.CacheGetChannel(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": "多密钥渠道不支持余额查询",
|
||||
})
|
||||
return
|
||||
}
|
||||
balance, err := updateChannelBalance(channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -436,7 +436,6 @@ func UpdateChannelBalance(c *gin.Context) {
|
||||
"message": "",
|
||||
"balance": balance,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func updateAllChannelsBalance() error {
|
||||
@@ -448,6 +447,9 @@ func updateAllChannelsBalance() error {
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
continue // skip multi-key channels
|
||||
}
|
||||
// TODO: support Azure
|
||||
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
||||
// continue
|
||||
@@ -458,7 +460,7 @@ func updateAllChannelsBalance() error {
|
||||
} else {
|
||||
// err is nil & balance <= 0 means quota is used up
|
||||
if balance <= 0 {
|
||||
service.DisableChannel(channel.Id, channel.Name, "余额不足")
|
||||
service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
|
||||
}
|
||||
}
|
||||
time.Sleep(common.RequestInterval)
|
||||
@@ -470,10 +472,7 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
||||
// TODO: make it async
|
||||
err := updateAllChannelsBalance()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -29,22 +30,43 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
type testResult struct {
|
||||
context *gin.Context
|
||||
localErr error
|
||||
newAPIError *types.NewAPIError
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
tik := time.Now()
|
||||
if channel.Type == constant.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("midjourney channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||
return errors.New("midjourney plus channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("midjourney plus channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||
return errors.New("suno channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("suno channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeKling {
|
||||
return errors.New("kling channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("kling channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeJimeng {
|
||||
return errors.New("jimeng channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("jimeng channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -81,31 +103,49 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
cache, err := model.GetUserCache(1)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
localErr: err,
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
cache.WriteContext(c)
|
||||
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
group, _ := model.GetUserGroup(1, false)
|
||||
c.Set("group", group)
|
||||
|
||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
if newAPIError != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: newAPIError,
|
||||
newAPIError: newAPIError,
|
||||
}
|
||||
}
|
||||
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, nil)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||
}
|
||||
}
|
||||
testModel = info.UpstreamModelName
|
||||
|
||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
||||
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
||||
}
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel)
|
||||
@@ -116,45 +156,77 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
||||
}
|
||||
}
|
||||
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||
}
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||
}
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
|
||||
}
|
||||
}
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
|
||||
}
|
||||
}
|
||||
}
|
||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: respErr,
|
||||
newAPIError: respErr,
|
||||
}
|
||||
}
|
||||
if usageA == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: errors.New("usage is nil"),
|
||||
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
|
||||
}
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
result := w.Result()
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
|
||||
}
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
|
||||
@@ -187,7 +259,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
Other: other,
|
||||
})
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: nil,
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
@@ -230,31 +306,38 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
func TestChannel(c *gin.Context) {
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(channelId, true)
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
//defer func() {
|
||||
// if channel.ChannelInfo.IsMultiKey {
|
||||
// go func() { _ = channel.SaveChannelInfo() }()
|
||||
// }
|
||||
//}()
|
||||
testModel := c.Query("model")
|
||||
tik := time.Now()
|
||||
err, _ = testChannel(channel, testModel)
|
||||
result := testChannel(channel, testModel)
|
||||
if result.localErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": result.localErr.Error(),
|
||||
"time": 0.0,
|
||||
})
|
||||
return
|
||||
}
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
go channel.UpdateResponseTime(milliseconds)
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
if err != nil {
|
||||
if result.newAPIError != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": result.newAPIError.Error(),
|
||||
"time": consumedTime,
|
||||
})
|
||||
return
|
||||
@@ -279,9 +362,9 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
testAllChannelsRunning = true
|
||||
testAllChannelsLock.Unlock()
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
return err
|
||||
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
||||
if getChannelErr != nil {
|
||||
return getChannelErr
|
||||
}
|
||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||
if disableThreshold == 0 {
|
||||
@@ -298,32 +381,34 @@ func testAllChannels(notify bool) error {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiWithStatusErr := testChannel(channel, "")
|
||||
result := testChannel(channel, "")
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
shouldBanChannel := false
|
||||
|
||||
newAPIError := result.newAPIError
|
||||
// request error disables the channel
|
||||
if openaiWithStatusErr != nil {
|
||||
oaiErr := openaiWithStatusErr.Error
|
||||
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
||||
if newAPIError != nil {
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
||||
}
|
||||
|
||||
if milliseconds > disableThreshold {
|
||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
shouldBanChannel = true
|
||||
// 当错误检查通过,才检查响应时间
|
||||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||
if milliseconds > disableThreshold {
|
||||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
|
||||
shouldBanChannel = true
|
||||
}
|
||||
}
|
||||
|
||||
// disable channel
|
||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
}
|
||||
|
||||
// enable channel
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
|
||||
service.EnableChannel(channel.Id, channel.Name)
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
||||
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
||||
}
|
||||
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
@@ -340,10 +425,7 @@ func testAllChannels(notify bool) error {
|
||||
func TestAllChannels(c *gin.Context) {
|
||||
err := testAllChannels(true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -354,6 +436,10 @@ func TestAllChannels(c *gin.Context) {
|
||||
}
|
||||
|
||||
func AutomaticallyTestChannels(frequency int) {
|
||||
if frequency <= 0 {
|
||||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||
return
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("testing all channels")
|
||||
|
||||
@@ -53,14 +53,7 @@ func parseStatusFilter(statusParam string) int {
|
||||
}
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
@@ -79,7 +72,7 @@ func GetAllChannels(c *gin.Context) {
|
||||
var total int64
|
||||
|
||||
if enableTagMode {
|
||||
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
|
||||
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
@@ -126,7 +119,7 @@ func GetAllChannels(c *gin.Context) {
|
||||
order = "id desc"
|
||||
}
|
||||
|
||||
err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
|
||||
err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
@@ -148,17 +141,12 @@ func GetAllChannels(c *gin.Context) {
|
||||
for _, r := range results {
|
||||
typeCounts[r.Type] = r.Count
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
common.ApiSuccess(c, gin.H{
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": pageInfo.GetPage(),
|
||||
"page_size": pageInfo.GetPageSize(),
|
||||
"type_counts": typeCounts,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -166,19 +154,13 @@ func GetAllChannels(c *gin.Context) {
|
||||
func FetchUpstreamModels(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -195,10 +177,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -230,10 +209,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
func FixChannelsAbilities(c *gin.Context) {
|
||||
success, fails, err := model.FixAbility()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -358,18 +334,12 @@ func SearchChannels(c *gin.Context) {
|
||||
func GetChannel(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -380,17 +350,53 @@ func GetChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func AddChannel(c *gin.Context) {
|
||||
channel := model.Channel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
type AddChannelRequest struct {
|
||||
Mode string `json:"mode"`
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
Channel *model.Channel `json:"channel"`
|
||||
}
|
||||
|
||||
func getVertexArrayKeys(keys string) ([]string, error) {
|
||||
if keys == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var keyArray []interface{}
|
||||
err := common.Unmarshal([]byte(keys), &keyArray)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
|
||||
}
|
||||
cleanKeys := make([]string, 0, len(keyArray))
|
||||
for _, key := range keyArray {
|
||||
var keyStr string
|
||||
switch v := key.(type) {
|
||||
case string:
|
||||
keyStr = strings.TrimSpace(v)
|
||||
default:
|
||||
bytes, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
|
||||
}
|
||||
keyStr = string(bytes)
|
||||
}
|
||||
if keyStr != "" {
|
||||
cleanKeys = append(cleanKeys, keyStr)
|
||||
}
|
||||
}
|
||||
if len(cleanKeys) == 0 {
|
||||
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
|
||||
}
|
||||
return cleanKeys, nil
|
||||
}
|
||||
|
||||
func AddChannel(c *gin.Context) {
|
||||
addChannelRequest := AddChannelRequest{}
|
||||
err := c.ShouldBindJSON(&addChannelRequest)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
err = channel.ValidateSettings()
|
||||
|
||||
err = addChannelRequest.Channel.ValidateSettings()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -398,56 +404,117 @@ func AddChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
channel.CreatedTime = common.GetTimestamp()
|
||||
keys := strings.Split(channel.Key, "\n")
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
|
||||
if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "channel cannot be empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the length of the model name
|
||||
for _, m := range addChannelRequest.Channel.GetModels() {
|
||||
if len(m) > 255 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("模型名称过长: %s", m),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
if addChannelRequest.Channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区不能为空",
|
||||
})
|
||||
return
|
||||
} else {
|
||||
if common.IsJsonStr(channel.Other) {
|
||||
// must have default
|
||||
regionMap := common.StrToMap(channel.Other)
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
|
||||
})
|
||||
return
|
||||
}
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
keys = []string{channel.Key}
|
||||
}
|
||||
|
||||
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
|
||||
keys := make([]string, 0)
|
||||
switch addChannelRequest.Mode {
|
||||
case "multi_to_single":
|
||||
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
|
||||
addChannelRequest.Channel.Key = strings.Join(array, "\n")
|
||||
} else {
|
||||
cleanKeys := make([]string, 0)
|
||||
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
cleanKeys = append(cleanKeys, key)
|
||||
}
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
|
||||
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
||||
}
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
case "batch":
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
// multi json
|
||||
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
|
||||
}
|
||||
case "single":
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不支持的添加模式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channels := make([]model.Channel, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
localChannel := channel
|
||||
localChannel := addChannelRequest.Channel
|
||||
localChannel.Key = key
|
||||
// Validate the length of the model name
|
||||
models := strings.Split(localChannel.Models, ",")
|
||||
for _, model := range models {
|
||||
if len(model) > 255 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("模型名称过长: %s", model),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
channels = append(channels, localChannel)
|
||||
channels = append(channels, *localChannel)
|
||||
}
|
||||
err = model.BatchInsertChannels(channels)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -462,10 +529,7 @@ func DeleteChannel(c *gin.Context) {
|
||||
channel := model.Channel{Id: id}
|
||||
err := channel.Delete()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -478,10 +542,7 @@ func DeleteChannel(c *gin.Context) {
|
||||
func DeleteDisabledChannel(c *gin.Context) {
|
||||
rows, err := model.DeleteDisabledChannel()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -514,10 +575,7 @@ func DisableTagChannels(c *gin.Context) {
|
||||
}
|
||||
err = model.DisableChannelByTag(channelTag.Tag)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -539,10 +597,7 @@ func EnableTagChannels(c *gin.Context) {
|
||||
}
|
||||
err = model.EnableChannelByTag(channelTag.Tag)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -571,10 +626,7 @@ func EditTagChannels(c *gin.Context) {
|
||||
}
|
||||
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -601,10 +653,7 @@ func DeleteChannelBatch(c *gin.Context) {
|
||||
}
|
||||
err = model.BatchDeleteChannels(channelBatch.Ids)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -615,14 +664,16 @@ func DeleteChannelBatch(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
type PatchChannel struct {
|
||||
model.Channel
|
||||
MultiKeyMode *string `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
channel := model.Channel{}
|
||||
channel := PatchChannel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
err = channel.ValidateSettings()
|
||||
@@ -641,20 +692,25 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
} else {
|
||||
if common.IsJsonStr(channel.Other) {
|
||||
// must have default
|
||||
regionMap := common.StrToMap(channel.Other)
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
regionMap, err := common.StrToMap(channel.Other)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
|
||||
})
|
||||
return
|
||||
}
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
err = channel.Update()
|
||||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||||
originChannel, err := model.GetChannelById(channel.Id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -662,6 +718,19 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
|
||||
channel.ChannelInfo = originChannel.ChannelInfo
|
||||
|
||||
// If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
|
||||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||||
}
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel.Key = ""
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
@@ -764,10 +833,7 @@ func BatchSetChannelTag(c *gin.Context) {
|
||||
}
|
||||
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -818,3 +884,52 @@ func GetTagModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// CopyChannel handles cloning an existing channel with its key.
|
||||
// POST /api/channel/copy/:id
|
||||
// Optional query params:
|
||||
// suffix - string appended to the original name (default "_复制")
|
||||
// reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
|
||||
func CopyChannel(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
|
||||
return
|
||||
}
|
||||
|
||||
suffix := c.DefaultQuery("suffix", "_复制")
|
||||
resetBalance := true
|
||||
if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
|
||||
if v, err := strconv.ParseBool(rbStr); err == nil {
|
||||
resetBalance = v
|
||||
}
|
||||
}
|
||||
|
||||
// fetch original channel with key
|
||||
origin, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// clone channel
|
||||
clone := *origin // shallow copy is sufficient as we will overwrite primitives
|
||||
clone.Id = 0 // let DB auto-generate
|
||||
clone.CreatedTime = common.GetTimestamp()
|
||||
clone.Name = origin.Name + suffix
|
||||
clone.TestTime = 0
|
||||
clone.ResponseTime = 0
|
||||
if resetBalance {
|
||||
clone.Balance = 0
|
||||
clone.UsedQuota = 0
|
||||
}
|
||||
|
||||
// insert
|
||||
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// success
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
|
||||
}
|
||||
|
||||
@@ -5,13 +5,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GitHubOAuthResponse struct {
|
||||
@@ -103,10 +104,7 @@ func GitHubOAuth(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -185,10 +183,7 @@ func GitHubBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -207,19 +202,13 @@ func GitHubBind(c *gin.Context) {
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.GitHubId = githubUser.Login
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -239,10 +228,7 @@ func GenerateOAuthCode(c *gin.Context) {
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -38,10 +38,7 @@ func LinuxDoBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -63,20 +60,14 @@ func LinuxDoBind(c *gin.Context) {
|
||||
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -202,10 +193,7 @@ func LinuxdoOAuth(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -10,14 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func GetAllLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
@@ -26,38 +19,19 @@ func GetAllLogs(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": map[string]any{
|
||||
"items": logs,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(logs)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
userId := c.GetInt("id")
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
@@ -65,24 +39,14 @@ func GetUserLogs(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": map[string]any{
|
||||
"items": logs,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(logs)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -90,10 +54,7 @@ func SearchAllLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
logs, err := model.SearchAllLogs(keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -109,10 +70,7 @@ func SearchUserLogs(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
logs, err := model.SearchUserLogs(userId, keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -198,10 +156,7 @@ func DeleteHistoryLogs(c *gin.Context) {
|
||||
}
|
||||
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -13,8 +12,9 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func UpdateMidjourneyTaskBulk() {
|
||||
@@ -213,14 +213,7 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
||||
}
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
// 解析其他查询参数
|
||||
queryParams := model.TaskQueryParams{
|
||||
@@ -230,7 +223,7 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
||||
items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.CountAllTasks(queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
@@ -239,27 +232,13 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func GetUserMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
@@ -269,7 +248,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
||||
items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.CountAllUserTask(userId, queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
@@ -278,14 +257,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
@@ -217,10 +217,7 @@ func SendEmailVerification(c *gin.Context) {
|
||||
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -256,10 +253,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -294,10 +288,7 @@ func ResetPassword(c *gin.Context) {
|
||||
password := common.GenerateVerificationCode(12)
|
||||
err = model.ResetUserPasswordByEmail(req.Email, password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.DeleteKey(req.Email, common.PasswordResetPurpose)
|
||||
|
||||
@@ -126,10 +126,7 @@ func OidcAuth(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -195,10 +192,7 @@ func OidcBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -217,19 +211,13 @@ func OidcBind(c *gin.Context) {
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -160,10 +160,7 @@ func UpdateOption(c *gin.Context) {
|
||||
}
|
||||
err = model.UpdateOption(option.Key, option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -3,45 +3,44 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
if newAPIError != nil {
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
@@ -52,26 +51,32 @@ func Playground(c *gin.Context) {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
//c.Set("token_name", "playground-"+group)
|
||||
tempToken := &model.Token{
|
||||
UserId: userId,
|
||||
Name: fmt.Sprintf("playground-%s", group),
|
||||
Group: group,
|
||||
}
|
||||
_ = middleware.SetupContextForToken(c, tempToken)
|
||||
_, err = getChannel(c, group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userId := c.GetInt("id")
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
|
||||
@@ -1,91 +1,51 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllRedemptions(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": redemptions,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(redemptions)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchRedemptions(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": redemptions,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(redemptions)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetRedemption(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
redemption, err := model.GetRedemptionById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -100,10 +60,7 @@ func AddRedemption(c *gin.Context) {
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
||||
@@ -165,10 +122,7 @@ func DeleteRedemption(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
err := model.DeleteRedemptionById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -183,18 +137,12 @@ func UpdateRedemption(c *gin.Context) {
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if statusOnly == "" {
|
||||
@@ -212,10 +160,7 @@ func UpdateRedemption(c *gin.Context) {
|
||||
}
|
||||
err = cleanRedemption.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -229,16 +174,13 @@ func UpdateRedemption(c *gin.Context) {
|
||||
func DeleteInvalidRedemption(c *gin.Context) {
|
||||
rows, err := model.DeleteInvalidRedemptions()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": rows,
|
||||
"data": rows,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -17,14 +17,15 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
var err *types.NewAPIError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
err = relay.ImageHelper(c)
|
||||
@@ -55,14 +56,14 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.Error.Type
|
||||
other["error_code"] = err.Error.Code
|
||||
other["error_type"] = err.ErrorType
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -73,25 +74,25 @@ func Relay(c *gin.Context) {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = relayRequest(c, relayMode, channel)
|
||||
newAPIError = relayRequest(c, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -101,14 +102,14 @@ func Relay(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,8 +128,7 @@ func WssRelay(c *gin.Context) {
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -137,25 +137,25 @@ func WssRelay(c *gin.Context) {
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = wssRequest(c, ws, relayMode, channel)
|
||||
newAPIError = wssRequest(c, ws, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -165,12 +165,12 @@ func WssRelay(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,27 +179,25 @@ func RelayClaude(c *gin.Context) {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var claudeErr *dto.ClaudeErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
claudeErr = claudeRequest(c, channel)
|
||||
newAPIError = claudeRequest(c, channel)
|
||||
|
||||
if claudeErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -209,30 +207,30 @@ func RelayClaude(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if claudeErr != nil {
|
||||
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
|
||||
c.JSON(claudeErr.StatusCode, gin.H{
|
||||
if newAPIError != nil {
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": claudeErr.Error,
|
||||
"error": newAPIError.ToClaudeError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relayHandler(c, relayMode)
|
||||
}
|
||||
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
@@ -245,7 +243,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
|
||||
c.Set("use_channel", useChannel)
|
||||
}
|
||||
|
||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
|
||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||||
if retryCount == 0 {
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
autoBanInt := 1
|
||||
@@ -259,19 +257,28 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||
if group == "auto" {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||
}
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
return nil, newAPIError
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
||||
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
|
||||
if openaiErr == nil {
|
||||
return false
|
||||
}
|
||||
if openaiErr.LocalError {
|
||||
if types.IsChannelError(openaiErr) {
|
||||
return true
|
||||
}
|
||||
if types.IsLocalError(openaiErr) {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
@@ -310,12 +317,12 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,9 +395,10 @@ func RelayTask(c *gin.Context) {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||||
if newAPIError != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
channelId = channel.Id
|
||||
@@ -398,9 +406,9 @@ func RelayTask(c *gin.Context) {
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
taskErr = taskRelayHandler(c, relayMode)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -17,6 +15,9 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func UpdateTaskBulk() {
|
||||
@@ -225,14 +226,7 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
|
||||
}
|
||||
|
||||
func GetAllTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
@@ -247,30 +241,15 @@ func GetAllTask(c *gin.Context) {
|
||||
ChannelID: c.Query("channel_id"),
|
||||
}
|
||||
|
||||
items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
||||
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllTasks(queryParams)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func GetUserTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
@@ -286,17 +265,9 @@ func GetUserTask(c *gin.Context) {
|
||||
EndTimestamp: endTimestamp,
|
||||
}
|
||||
|
||||
items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
||||
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
@@ -1,46 +1,26 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
size, _ := strconv.Atoi(c.Query("size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if size <= 0 {
|
||||
size = common.ItemsPerPage
|
||||
} else if size > 100 {
|
||||
size = 100
|
||||
}
|
||||
tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// Get total count for pagination
|
||||
total, _ := model.CountUserTokens(userId)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": tokens,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": size,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tokens)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -50,10 +30,7 @@ func SearchTokens(c *gin.Context) {
|
||||
token := c.Query("token")
|
||||
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -68,18 +45,12 @@ func GetToken(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt("id")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
token, err := model.GetTokenByIds(id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -95,10 +66,7 @@ func GetTokenStatus(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
token, err := model.GetTokenByIds(tokenId, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
expiredAt := token.ExpiredTime
|
||||
@@ -118,10 +86,7 @@ func AddToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
@@ -156,10 +121,7 @@ func AddToken(c *gin.Context) {
|
||||
}
|
||||
err = cleanToken.Insert()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -174,10 +136,7 @@ func DeleteToken(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
err := model.DeleteTokenById(id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -193,10 +152,7 @@ func UpdateToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
@@ -208,10 +164,7 @@ func UpdateToken(c *gin.Context) {
|
||||
}
|
||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if token.Status == common.TokenStatusEnabled {
|
||||
@@ -245,10 +198,7 @@ func UpdateToken(c *gin.Context) {
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -275,10 +225,7 @@ func DeleteTokenBatch(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllQuotaDates(c *gin.Context) {
|
||||
@@ -13,10 +15,7 @@ func GetAllQuotaDates(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -41,10 +40,7 @@ func GetUserQuotaDates(c *gin.Context) {
|
||||
}
|
||||
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -188,10 +188,7 @@ func Register(c *gin.Context) {
|
||||
cleanUser.Email = user.Email
|
||||
}
|
||||
if err := cleanUser.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -247,81 +244,45 @@ func Register(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetAllUsers(c *gin.Context) {
|
||||
pageInfo, err := common.GetPageQuery(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "parse page query failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.GetAllUsers(pageInfo)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": pageInfo,
|
||||
})
|
||||
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
startIdx := (p - 1) * pageSize
|
||||
users, total, err := model.SearchUsers(keyword, group, startIdx, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": users,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
@@ -344,10 +305,7 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// get rand int 28-32
|
||||
@@ -372,10 +330,7 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -395,18 +350,12 @@ func TransferAffQuota(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
tran := TransferAffQuotaRequest{}
|
||||
if err := c.ShouldBindJSON(&tran); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
err = user.TransferAffQuotaToQuota(tran.Quota)
|
||||
@@ -427,10 +376,7 @@ func GetAffCode(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if user.AffCode == "" {
|
||||
@@ -455,10 +401,7 @@ func GetSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||
@@ -479,10 +422,7 @@ func GetUserModels(c *gin.Context) {
|
||||
}
|
||||
user, err := model.GetUserCache(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
groups := setting.GetUserUsableGroups(user.Group)
|
||||
@@ -524,10 +464,7 @@ func UpdateUser(c *gin.Context) {
|
||||
}
|
||||
originUser, err := model.GetUserById(updatedUser.Id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
@@ -550,10 +487,7 @@ func UpdateUser(c *gin.Context) {
|
||||
}
|
||||
updatePassword := updatedUser.Password != ""
|
||||
if err := updatedUser.Edit(updatePassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if originUser.Quota != updatedUser.Quota {
|
||||
@@ -599,17 +533,11 @@ func UpdateSelf(c *gin.Context) {
|
||||
}
|
||||
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := cleanUser.Update(updatePassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -640,18 +568,12 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int
|
||||
func DeleteUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
@@ -686,10 +608,7 @@ func DeleteSelf(c *gin.Context) {
|
||||
|
||||
err := model.DeleteUserById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -735,10 +654,7 @@ func CreateUser(c *gin.Context) {
|
||||
DisplayName: user.DisplayName,
|
||||
}
|
||||
if err := cleanUser.Insert(0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -848,10 +764,7 @@ func ManageUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
clearUser := model.User{
|
||||
@@ -883,20 +796,14 @@ func EmailBind(c *gin.Context) {
|
||||
}
|
||||
err := user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.Email = email
|
||||
// no need to check if this email already taken, because we have used verification code to check it
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -918,19 +825,13 @@ func TopUp(c *gin.Context) {
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
quota, err := model.Redeem(req.Key, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -1013,10 +914,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type wechatLoginResponse struct {
|
||||
@@ -150,19 +151,13 @@ func WeChatBind(c *gin.Context) {
|
||||
}
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.WeChatId = wechatId
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -3,6 +3,7 @@ package dto
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
@@ -228,7 +229,7 @@ type ClaudeResponse struct {
|
||||
Completion string `json:"completion,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Error *ClaudeError `json:"error,omitempty"`
|
||||
Error *types.ClaudeError `json:"error,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||
|
||||
12
dto/error.go
12
dto/error.go
@@ -1,5 +1,7 @@
|
||||
package dto
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
@@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct {
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Error types.OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
|
||||
@@ -55,6 +55,7 @@ type GeneralOpenAIRequest struct {
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
SearchParameters any `json:"search_parameters,omitempty"` //xai
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// OpenRouter Params
|
||||
Usage json.RawMessage `json:"usage,omitempty"`
|
||||
@@ -65,8 +66,8 @@ type GeneralOpenAIRequest struct {
|
||||
|
||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
result := make(map[string]any)
|
||||
data, _ := common.EncodeJson(r)
|
||||
_ = common.UnmarshalJson(data, &result)
|
||||
data, _ := common.Marshal(r)
|
||||
_ = common.Unmarshal(data, &result)
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
@@ -28,7 +31,7 @@ type OpenAITextResponse struct {
|
||||
Object string `json:"object"`
|
||||
Created any `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
@@ -201,7 +204,7 @@ type OpenAIResponsesResponse struct {
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package dto
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
const (
|
||||
RealtimeEventTypeError = "error"
|
||||
RealtimeEventTypeSessionUpdate = "session.update"
|
||||
@@ -23,12 +25,12 @@ type RealtimeEvent struct {
|
||||
EventId string `json:"event_id"`
|
||||
Type string `json:"type"`
|
||||
//PreviousItemId string `json:"previous_item_id"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type RealtimeResponse struct {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -233,30 +234,41 @@ func TokenAuth() func(c *gin.Context) {
|
||||
|
||||
userCache.WriteContext(c)
|
||||
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
||||
if !token.UnlimitedQuota {
|
||||
c.Set("token_quota", token.RemainQuota)
|
||||
}
|
||||
if token.ModelLimitsEnabled {
|
||||
c.Set("token_model_limit_enabled", true)
|
||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
}
|
||||
err = SetupContextForToken(c, token, parts...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token is nil")
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
||||
if !token.UnlimitedQuota {
|
||||
c.Set("token_quota", token.RemainQuota)
|
||||
}
|
||||
if token.ModelLimitsEnabled {
|
||||
c.Set("token_model_limit_enabled", true)
|
||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return fmt.Errorf("普通用户不支持指定渠道")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
|
||||
type ModelRequest struct {
|
||||
Model string `json:"model"`
|
||||
Group string `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
@@ -237,28 +239,47 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||
// playground chat completions
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
return nil, false, errors.New("无效的请求, " + err.Error())
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
|
||||
}
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
return
|
||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
|
||||
}
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
c.Set("channel_create_time", channel.CreatedTime)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||
c.Set("param_override", channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||
}
|
||||
c.Set("auto_ban", channel.GetAutoBan())
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
|
||||
|
||||
key, index, newAPIError := channel.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
||||
}
|
||||
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeAzure:
|
||||
@@ -278,6 +299,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
case constant.ChannelTypeCoze:
|
||||
c.Set("bot_id", channel.Other)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
|
||||
|
||||
328
model/channel.go
328
model/channel.go
@@ -1,9 +1,15 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -36,8 +42,138 @@ type Channel struct {
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
}
|
||||
|
||||
type ChannelInfo struct {
|
||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer interface
|
||||
func (c ChannelInfo) Value() (driver.Value, error) {
|
||||
return common.Marshal(&c)
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner interface
|
||||
func (c *ChannelInfo) Scan(value interface{}) error {
|
||||
bytesValue, _ := value.([]byte)
|
||||
return common.Unmarshal(bytesValue, c)
|
||||
}
|
||||
|
||||
func (channel *Channel) getKeys() []string {
|
||||
if channel.Key == "" {
|
||||
return []string{}
|
||||
}
|
||||
trimmed := strings.TrimSpace(channel.Key)
|
||||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
res := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
res[i] = string(v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
// Otherwise, fall back to splitting by newline
|
||||
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
|
||||
return keys
|
||||
}
|
||||
|
||||
func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
// If not in multi-key mode, return the original key string directly.
|
||||
if !channel.ChannelInfo.IsMultiKey {
|
||||
return channel.Key, 0, nil
|
||||
}
|
||||
|
||||
// Obtain all keys (split by \n)
|
||||
keys := channel.getKeys()
|
||||
if len(keys) == 0 {
|
||||
// No keys available, return error, should disable the channel
|
||||
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||||
}
|
||||
|
||||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||
// helper to get key status, default to enabled when missing
|
||||
getStatus := func(idx int) int {
|
||||
if statusList == nil {
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
if status, ok := statusList[idx]; ok {
|
||||
return status
|
||||
}
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
|
||||
// Collect indexes of enabled keys
|
||||
enabledIdx := make([]int, 0, len(keys))
|
||||
for i := range keys {
|
||||
if getStatus(i) == common.ChannelStatusEnabled {
|
||||
enabledIdx = append(enabledIdx, i)
|
||||
}
|
||||
}
|
||||
// If no specific status list or none enabled, fall back to first key
|
||||
if len(enabledIdx) == 0 {
|
||||
return keys[0], 0, nil
|
||||
}
|
||||
|
||||
switch channel.ChannelInfo.MultiKeyMode {
|
||||
case constant.MultiKeyModeRandom:
|
||||
// Randomly pick one enabled key
|
||||
selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
|
||||
return keys[selectedIdx], selectedIdx, nil
|
||||
case constant.MultiKeyModePolling:
|
||||
// Use channel-specific lock to ensure thread-safe polling
|
||||
lock := getChannelPollingLock(channel.Id)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||||
if err != nil {
|
||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||
}
|
||||
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||||
defer func() {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
|
||||
}
|
||||
if !common.MemoryCacheEnabled {
|
||||
_ = channel.SaveChannelInfo()
|
||||
} else {
|
||||
// CacheUpdateChannel(channel)
|
||||
}
|
||||
}()
|
||||
// Start from the saved polling index and look for the next enabled key
|
||||
start := channelInfo.MultiKeyPollingIndex
|
||||
if start < 0 || start >= len(keys) {
|
||||
start = 0
|
||||
}
|
||||
for i := 0; i < len(keys); i++ {
|
||||
idx := (start + i) % len(keys)
|
||||
if getStatus(idx) == common.ChannelStatusEnabled {
|
||||
// update polling index for next call (point to the next position)
|
||||
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
|
||||
return keys[idx], idx, nil
|
||||
}
|
||||
}
|
||||
// Fallback – should not happen, but return first enabled key
|
||||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||||
default:
|
||||
// Unknown mode, default to first enabled key (or original key string)
|
||||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||||
}
|
||||
}
|
||||
|
||||
func (channel *Channel) SaveChannelInfo() error {
|
||||
return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
@@ -175,14 +311,20 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
}
|
||||
|
||||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
channel := Channel{Id: id}
|
||||
channel := &Channel{Id: id}
|
||||
var err error = nil
|
||||
if selectAll {
|
||||
err = DB.First(&channel, "id = ?", id).Error
|
||||
err = DB.First(channel, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Omit("key").First(&channel, "id = ?", id).Error
|
||||
err = DB.Omit("key").First(channel, "id = ?", id).Error
|
||||
}
|
||||
return &channel, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func BatchInsertChannels(channels []Channel) error {
|
||||
@@ -266,6 +408,44 @@ func (channel *Channel) Insert() error {
|
||||
}
|
||||
|
||||
func (channel *Channel) Update() error {
|
||||
// If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
var keyStr string
|
||||
if channel.Key != "" {
|
||||
keyStr = channel.Key
|
||||
} else {
|
||||
// If key is not provided, read the existing key from the database
|
||||
if existing, err := GetChannelById(channel.Id, true); err == nil {
|
||||
keyStr = existing.Key
|
||||
}
|
||||
}
|
||||
// Parse the key list (supports newline separation or JSON array)
|
||||
keys := []string{}
|
||||
if keyStr != "" {
|
||||
trimmed := strings.TrimSpace(keyStr)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
keys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
keys[i] = string(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 { // fallback to newline split
|
||||
keys = strings.Split(strings.Trim(keyStr, "\n"), "\n")
|
||||
}
|
||||
}
|
||||
channel.ChannelInfo.MultiKeySize = len(keys)
|
||||
// Clean up status data that exceeds the new key count to prevent index out of range
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
for idx := range channel.ChannelInfo.MultiKeyStatusList {
|
||||
if idx >= channel.ChannelInfo.MultiKeySize {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
var err error
|
||||
err = DB.Model(channel).Updates(channel).Error
|
||||
if err != nil {
|
||||
@@ -308,48 +488,128 @@ func (channel *Channel) Delete() error {
|
||||
|
||||
var channelStatusLock sync.Mutex
|
||||
|
||||
func UpdateChannelStatusById(id int, status int, reason string) bool {
|
||||
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
|
||||
var channelPollingLocks sync.Map
|
||||
|
||||
// getChannelPollingLock returns or creates a mutex for the given channel ID
|
||||
func getChannelPollingLock(channelId int) *sync.Mutex {
|
||||
if lock, exists := channelPollingLocks.Load(channelId); exists {
|
||||
return lock.(*sync.Mutex)
|
||||
}
|
||||
// Create new lock for this channel
|
||||
newLock := &sync.Mutex{}
|
||||
actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
|
||||
return actual.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// CleanupChannelPollingLocks removes locks for channels that no longer exist
|
||||
// This is optional and can be called periodically to prevent memory leaks
|
||||
func CleanupChannelPollingLocks() {
|
||||
var activeChannelIds []int
|
||||
DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
|
||||
|
||||
activeChannelSet := make(map[int]bool)
|
||||
for _, id := range activeChannelIds {
|
||||
activeChannelSet[id] = true
|
||||
}
|
||||
|
||||
channelPollingLocks.Range(func(key, value interface{}) bool {
|
||||
channelId := key.(int)
|
||||
if !activeChannelSet[channelId] {
|
||||
channelPollingLocks.Delete(channelId)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
|
||||
keys := channel.getKeys()
|
||||
if len(keys) == 0 {
|
||||
channel.Status = status
|
||||
} else {
|
||||
var keyIndex int
|
||||
for i, key := range keys {
|
||||
if key == usingKey {
|
||||
keyIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
if status == common.ChannelStatusEnabled {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||
} else {
|
||||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
||||
}
|
||||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||||
channel.Status = common.ChannelStatusAutoDisabled
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = "All keys are disabled"
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
|
||||
if common.MemoryCacheEnabled {
|
||||
channelStatusLock.Lock()
|
||||
defer channelStatusLock.Unlock()
|
||||
|
||||
channelCache, _ := CacheGetChannel(id)
|
||||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||||
if channelCache != nil && channelCache.Status == status {
|
||||
channelCache, _ := CacheGetChannel(channelId)
|
||||
if channelCache == nil {
|
||||
return false
|
||||
}
|
||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
||||
if channelCache == nil && status != common.ChannelStatusEnabled {
|
||||
return false
|
||||
if channelCache.ChannelInfo.IsMultiKey {
|
||||
// 如果是多Key模式,更新缓存中的状态
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status)
|
||||
//CacheUpdateChannel(channelCache)
|
||||
//return true
|
||||
} else {
|
||||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||||
if channelCache.Status == status {
|
||||
return false
|
||||
}
|
||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
||||
if status != common.ChannelStatusEnabled {
|
||||
return false
|
||||
}
|
||||
CacheUpdateChannelStatus(channelId, status)
|
||||
}
|
||||
CacheUpdateChannelStatus(id, status)
|
||||
}
|
||||
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
||||
|
||||
shouldUpdateAbilities := false
|
||||
defer func() {
|
||||
if shouldUpdateAbilities {
|
||||
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
||||
if err != nil {
|
||||
common.SysError("failed to update ability status: " + err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
channel, err := GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.SysError("failed to update ability status: " + err.Error())
|
||||
return false
|
||||
}
|
||||
channel, err := GetChannelById(id, true)
|
||||
if err != nil {
|
||||
// find channel by id error, directly update status
|
||||
result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
|
||||
if result.Error != nil {
|
||||
common.SysError("failed to update channel status: " + result.Error.Error())
|
||||
return false
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if channel.Status == status {
|
||||
return false
|
||||
}
|
||||
// find channel by id success, update status and other info
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = reason
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
channel.Status = status
|
||||
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
beforeStatus := channel.Status
|
||||
handlerMultiKeyUpdate(channel, usingKey, status)
|
||||
if beforeStatus != channel.Status {
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
} else {
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = reason
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
channel.Status = status
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
err = channel.Save()
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel status: " + err.Error())
|
||||
@@ -532,6 +792,8 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
channel.Setting = nil // 清空设置以避免后续错误
|
||||
_ = channel.Save() // 保存修改
|
||||
}
|
||||
}
|
||||
return setting
|
||||
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
var group2model2channels map[string]map[string][]int // enabled channel
|
||||
var channelsIDM map[int]*Channel // all channels include disabled
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
@@ -24,7 +24,7 @@ func InitChannelCache() {
|
||||
}
|
||||
newChannelId2channel := make(map[int]*Channel)
|
||||
var channels []*Channel
|
||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
||||
DB.Find(&channels)
|
||||
for _, channel := range channels {
|
||||
newChannelId2channel[channel.Id] = channel
|
||||
}
|
||||
@@ -34,21 +34,22 @@ func InitChannelCache() {
|
||||
for _, ability := range abilities {
|
||||
groups[ability.Group] = true
|
||||
}
|
||||
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
||||
newChannelsIDM := make(map[int]*Channel)
|
||||
newGroup2model2channels := make(map[string]map[string][]int)
|
||||
for group := range groups {
|
||||
newGroup2model2channels[group] = make(map[string][]*Channel)
|
||||
newGroup2model2channels[group] = make(map[string][]int)
|
||||
}
|
||||
for _, channel := range channels {
|
||||
newChannelsIDM[channel.Id] = channel
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue // skip disabled channels
|
||||
}
|
||||
groups := strings.Split(channel.Group, ",")
|
||||
for _, group := range groups {
|
||||
models := strings.Split(channel.Models, ",")
|
||||
for _, model := range models {
|
||||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
||||
newGroup2model2channels[group][model] = make([]*Channel, 0)
|
||||
newGroup2model2channels[group][model] = make([]int, 0)
|
||||
}
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -57,7 +58,7 @@ func InitChannelCache() {
|
||||
for group, model2channels := range newGroup2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return channels[i].GetPriority() > channels[j].GetPriority()
|
||||
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
|
||||
})
|
||||
newGroup2model2channels[group][model] = channels
|
||||
}
|
||||
@@ -65,7 +66,7 @@ func InitChannelCache() {
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelsIDM = newChannelsIDM
|
||||
channelsIDM = newChannelId2channel
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
@@ -128,16 +129,27 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
}
|
||||
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
channels := group2model2channels[group][model]
|
||||
channelSyncLock.RUnlock()
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
|
||||
if len(channels) == 1 {
|
||||
if channel, ok := channelsIDM[channels[0]]; ok {
|
||||
return channel, nil
|
||||
}
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
|
||||
}
|
||||
|
||||
uniquePriorities := make(map[int]bool)
|
||||
for _, channel := range channels {
|
||||
uniquePriorities[int(channel.GetPriority())] = true
|
||||
for _, channelId := range channels {
|
||||
if channel, ok := channelsIDM[channelId]; ok {
|
||||
uniquePriorities[int(channel.GetPriority())] = true
|
||||
} else {
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||||
}
|
||||
}
|
||||
var sortedUniquePriorities []int
|
||||
for priority := range uniquePriorities {
|
||||
@@ -152,9 +164,13 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
|
||||
// get the priority for the given retry number
|
||||
var targetChannels []*Channel
|
||||
for _, channel := range channels {
|
||||
if channel.GetPriority() == targetPriority {
|
||||
targetChannels = append(targetChannels, channel)
|
||||
for _, channelId := range channels {
|
||||
if channel, ok := channelsIDM[channelId]; ok {
|
||||
if channel.GetPriority() == targetPriority {
|
||||
targetChannels = append(targetChannels, channel)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,11 +204,35 @@ func CacheGetChannel(id int) (*Channel, error) {
|
||||
|
||||
c, ok := channelsIDM[id]
|
||||
if !ok {
|
||||
return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
channel, err := GetChannelById(id, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &channel.ChannelInfo, nil
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
c, ok := channelsIDM[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return &c.ChannelInfo, nil
|
||||
}
|
||||
|
||||
func CacheUpdateChannelStatus(id int, status int) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
@@ -203,3 +243,20 @@ func CacheUpdateChannelStatus(id int, status int) {
|
||||
channel.Status = status
|
||||
}
|
||||
}
|
||||
|
||||
func CacheUpdateChannel(channel *Channel) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
channelSyncLock.Lock()
|
||||
defer channelSyncLock.Unlock()
|
||||
if channel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
|
||||
|
||||
println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||||
channelsIDM[channel.Id] = channel
|
||||
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func formatUserLogs(logs []*Log) {
|
||||
for i := range logs {
|
||||
logs[i].ChannelName = ""
|
||||
var otherMap map[string]interface{}
|
||||
otherMap = common.StrToMap(logs[i].Other)
|
||||
otherMap, _ = common.StrToMap(logs[i].Other)
|
||||
if otherMap != nil {
|
||||
// delete admin
|
||||
delete(otherMap, "admin_info")
|
||||
|
||||
@@ -57,7 +57,7 @@ func initCol() {
|
||||
}
|
||||
}
|
||||
// log sql type and database type
|
||||
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
}
|
||||
|
||||
var DB *gorm.DB
|
||||
@@ -225,12 +225,6 @@ func InitLogDB() (err error) {
|
||||
if !common.IsMasterNode {
|
||||
return nil
|
||||
}
|
||||
//if common.UsingMySQL {
|
||||
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
|
||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
|
||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
|
||||
//}
|
||||
common.SysLog("database migration started")
|
||||
err = migrateLOGDB()
|
||||
return err
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -36,7 +35,7 @@ func (user *UserBase) WriteContext(c *gin.Context) {
|
||||
func (user *UserBase) GetSetting() dto.UserSetting {
|
||||
setting := dto.UserSetting{}
|
||||
if user.Setting != "" {
|
||||
err := json.Unmarshal([]byte(user.Setting), &setting)
|
||||
err := common.Unmarshal([]byte(user.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package relay
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -12,7 +11,10 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
||||
@@ -54,13 +56,13 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
return audioRequest, nil
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
promptTokens := 0
|
||||
@@ -73,7 +75,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -88,23 +90,23 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -112,18 +114,18 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -21,7 +22,7 @@ type Adaptor interface {
|
||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -99,7 +100,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
@@ -109,9 +110,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -4,15 +4,17 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
||||
@@ -124,49 +126,46 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
||||
return &imageResponse
|
||||
}
|
||||
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseFormat := c.GetString("response_format")
|
||||
|
||||
var aliTaskResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
if aliTaskResponse.Message != "" {
|
||||
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
|
||||
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, nil
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -31,29 +31,26 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var aliResponse AliRerankResponse
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
usage := dto.Usage{
|
||||
@@ -68,14 +65,10 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
@@ -8,9 +8,10 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -38,11 +39,11 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
|
||||
}
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var fullTextResponse dto.OpenAIEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
@@ -53,11 +54,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
|
||||
}
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
@@ -119,7 +120,7 @@ func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStre
|
||||
return &response
|
||||
}
|
||||
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
@@ -174,32 +175,29 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: "ali_error",
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -84,7 +85,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -3,19 +3,22 @@ package aws
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
)
|
||||
|
||||
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||
@@ -65,24 +68,21 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
||||
return modelPrefix + "." + awsModelId
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) (string, error) {
|
||||
func awsModelID(requestModel string) string {
|
||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelID, nil
|
||||
return awsModelID
|
||||
}
|
||||
|
||||
return requestModel, nil
|
||||
return requestModel
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
@@ -98,42 +98,42 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
|
||||
claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
@@ -149,25 +149,25 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
|
||||
return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
stream := awsResp.GetStream()
|
||||
defer stream.Close()
|
||||
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
@@ -176,18 +176,18 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
for event := range stream.Events() {
|
||||
switch v := event.(type) {
|
||||
case *types.ResponseStreamMemberChunk:
|
||||
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
|
||||
info.SetFirstResponseTime()
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
||||
if respErr != nil {
|
||||
return respErr, nil
|
||||
}
|
||||
case *types.UnknownUnionMember:
|
||||
case *bedrockruntimeTypes.UnknownUnionMember:
|
||||
fmt.Println("unknown tag:", v.Tag)
|
||||
return wrapErr(errors.New("unknown response type")), nil
|
||||
return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
|
||||
default:
|
||||
fmt.Println("union is nil or unknown type")
|
||||
return wrapErr(errors.New("nil or unknown response type")), nil
|
||||
return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -140,15 +141,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = baiduStreamHandler(c, resp)
|
||||
err, usage = baiduStreamHandler(c, info, resp)
|
||||
} else {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = baiduEmbeddingHandler(c, resp)
|
||||
err, usage = baiduEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
err, usage = baiduHandler(c, resp)
|
||||
err, usage = baiduHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||
@@ -110,92 +112,49 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
usage := &dto.Usage{}
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := common.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("error sending stream response: " + err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, &usage
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var baiduResponse BaiduChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -203,32 +162,24 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var baiduResponse BaiduEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -92,11 +93,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -94,7 +95,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -125,7 +126,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
|
||||
if textRequest.Reasoning != nil {
|
||||
var reasoning openrouter.RequestReasoning
|
||||
if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
|
||||
if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -517,22 +518,15 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
return true
|
||||
}
|
||||
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Code: "stream_response_error",
|
||||
Type: claudeResponse.Error.Type,
|
||||
Message: claudeResponse.Error.Message,
|
||||
},
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
@@ -593,15 +587,15 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
var err *types.NewAPIError
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
|
||||
if err != nil {
|
||||
@@ -617,21 +611,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.UnmarshalJson(data, &claudeResponse)
|
||||
err := common.Unmarshal(data, &claudeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
@@ -652,7 +639,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
openaiResponse.Usage = *claudeInfo.Usage
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
responseData = data
|
||||
@@ -662,11 +649,11 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
@@ -674,7 +661,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("responseBody: ", string(responseBody))
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -94,20 +95,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
err, usage = cfStreamHandler(c, resp, info)
|
||||
err, usage = cfStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = cfHandler(c, resp, info)
|
||||
err, usage = cfHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
err, usage = cfSTTHandler(c, resp, info)
|
||||
err, usage = cfSTTHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package cloudflare
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -11,8 +10,11 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
|
||||
@@ -25,7 +27,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
|
||||
}
|
||||
}
|
||||
|
||||
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
@@ -86,16 +88,16 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var response dto.TextResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
response.Model = info.UpstreamModelName
|
||||
var responseText string
|
||||
@@ -107,7 +109,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -115,16 +117,16 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var cfResp CfAudioResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &cfResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
audioResp := &dto.AudioResponse{
|
||||
@@ -133,7 +135,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
|
||||
jsonResponse, err := json.Marshal(audioResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -71,14 +72,14 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = cohereRerankHandler(c, resp, info)
|
||||
usage, err = cohereRerankHandler(c, resp, info)
|
||||
} else {
|
||||
if info.IsStream {
|
||||
err, usage = cohereStreamHandler(c, resp, info)
|
||||
usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
|
||||
} else {
|
||||
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||
usage, err = cohereHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -3,7 +3,6 @@ package cohere
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -11,8 +10,11 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||
@@ -76,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
usage := &dto.Usage{}
|
||||
@@ -164,20 +166,20 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
if usage.PromptTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
createdTime := common.GetTimestamp()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||
@@ -188,7 +190,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
openaiResp.Id = cohereResp.ResponseId
|
||||
openaiResp.Created = createdTime
|
||||
openaiResp.Object = "chat.completion"
|
||||
openaiResp.Model = modelName
|
||||
openaiResp.Model = info.UpstreamModelName
|
||||
openaiResp.Usage = usage
|
||||
|
||||
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
||||
@@ -201,24 +203,24 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
|
||||
jsonResponse, err := json.Marshal(openaiResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereRerankResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
||||
@@ -237,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/common"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -95,11 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
|
||||
}
|
||||
|
||||
// DoResponse implements channel.Adaptor.
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = cozeChatStreamHandler(c, resp, info)
|
||||
usage, err = cozeChatStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = cozeChatHandler(c, resp, info)
|
||||
usage, err = cozeChatHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -43,10 +44,10 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
|
||||
return cozeRequest
|
||||
}
|
||||
|
||||
func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
// convert coze response to openai response
|
||||
@@ -55,10 +56,10 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
response.Model = info.UpstreamModelName
|
||||
err = json.Unmarshal(responseBody, &cozeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if cozeResponse.Code != 0 {
|
||||
return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
// 从上下文获取 usage
|
||||
var usage dto.Usage
|
||||
@@ -85,16 +86,16 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
return nil, &usage
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
@@ -135,7 +136,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
@@ -143,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
}
|
||||
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -81,11 +82,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -96,11 +97,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = difyStreamHandler(c, resp, info)
|
||||
return difyStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = difyHandler(c, resp, info)
|
||||
return difyHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
@@ -209,7 +210,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
|
||||
return &response
|
||||
}
|
||||
|
||||
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var responseText string
|
||||
usage := &dto.Usage{}
|
||||
var nodeToken int
|
||||
@@ -247,20 +248,20 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var difyResponse DifyChatCompletionResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &difyResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: difyResponse.ConversationId,
|
||||
@@ -279,10 +280,10 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &difyResponse.MetaData.Usage
|
||||
c.Writer.Write(jsonResponse)
|
||||
return &difyResponse.MetaData.Usage, nil
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -168,30 +168,30 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
if info.IsStream {
|
||||
return GeminiTextGenerationStreamHandler(c, resp, info)
|
||||
return GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
} else {
|
||||
return GeminiTextGenerationHandler(c, resp, info)
|
||||
return GeminiTextGenerationHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return GeminiImageHandler(c, resp, info)
|
||||
return GeminiImageHandler(c, info, resp)
|
||||
}
|
||||
|
||||
// check if the model is an embedding model
|
||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
||||
return GeminiEmbeddingHandler(c, resp, info)
|
||||
return GeminiEmbeddingHandler(c, info, resp)
|
||||
}
|
||||
|
||||
if info.IsStream {
|
||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||
return GeminiChatStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = GeminiChatHandler(c, resp, info)
|
||||
return GeminiChatHandler(c, info, resp)
|
||||
}
|
||||
|
||||
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
|
||||
@@ -205,23 +205,23 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
// }
|
||||
//}
|
||||
|
||||
return
|
||||
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse GeminiImageResponse
|
||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
if len(geminiResponse.Predictions) == 0 {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
|
||||
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
@@ -241,7 +241,7 @@ func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
@@ -253,7 +253,7 @@ func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
|
||||
const imageTokens = 258
|
||||
generatedImages := len(openAIResponse.Data)
|
||||
|
||||
usage = &dto.Usage{
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
||||
CompletionTokens: 0, // image generation does not calculate completion tokens
|
||||
TotalTokens: imageTokens * generatedImages,
|
||||
|
||||
@@ -8,18 +8,19 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
@@ -28,9 +29,9 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
|
||||
// 解析为 Gemini 原生响应格式
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = common.UnmarshalJson(responseBody, &geminiResponse)
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
@@ -51,9 +52,9 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||
jsonResponse, err := common.EncodeJson(geminiResponse)
|
||||
jsonResponse, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
@@ -61,7 +62,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
@@ -792,7 +794,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
return &response, isStop, hasImage
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
// responseText := ""
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
@@ -858,33 +860,25 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
}
|
||||
helper.Done(c)
|
||||
//resp.Body.Close()
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = common.UnmarshalJson(responseBody, &geminiResponse)
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
@@ -908,25 +902,25 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
c.Writer.Write(jsonResponse)
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
var geminiResponse GeminiEmbeddingResponse
|
||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
@@ -947,16 +941,16 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
|
||||
// Google has not yet clarified how embedding models will be billed
|
||||
// refer to openai billing method to use input tokens billing
|
||||
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
|
||||
usage = &dto.Usage{
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
openAIResponse.Usage = *usage.(*dto.Usage)
|
||||
openAIResponse.Usage = *usage
|
||||
|
||||
jsonResponse, jsonErr := common.EncodeJson(openAIResponse)
|
||||
jsonResponse, jsonErr := common.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/common_handler"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -73,11 +74,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||
usage, err = common_handler.RerankHandler(c, info, resp)
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -69,11 +70,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -84,11 +85,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = mokaEmbeddingHandler(c, resp)
|
||||
return mokaEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
// err, usage = mokaHandler(c, resp)
|
||||
|
||||
|
||||
@@ -2,12 +2,14 @@ package mokaai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
|
||||
@@ -48,16 +50,16 @@ func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEm
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var baiduResponse dto.EmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
// if baiduResponse.ErrorMsg != "" {
|
||||
// return &dto.OpenAIErrorWithStatusCode{
|
||||
@@ -69,12 +71,12 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
// }, nil
|
||||
// }
|
||||
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return &fullTextResponse.Usage, nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -74,14 +75,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
@@ -82,19 +84,19 @@ func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequ
|
||||
}
|
||||
}
|
||||
|
||||
func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if ollamaEmbeddingResponse.Error != "" {
|
||||
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
|
||||
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
@@ -103,22 +105,22 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
||||
Object: "embedding",
|
||||
})
|
||||
usage := &dto.Usage{
|
||||
TotalTokens: promptTokens,
|
||||
TotalTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
PromptTokens: promptTokens,
|
||||
PromptTokens: info.PromptTokens,
|
||||
}
|
||||
embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
Model: model,
|
||||
Model: info.UpstreamModelName,
|
||||
Usage: *usage,
|
||||
}
|
||||
doResponseBody, err := json.Marshal(embeddingResponse)
|
||||
doResponseBody, err := common.Marshal(embeddingResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func flattenEmbeddings(embeddings [][]float64) []float64 {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"one-api/relay/common_handler"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -34,9 +35,9 @@ type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if !strings.Contains(request.Model, "claude") {
|
||||
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
}
|
||||
//if !strings.Contains(request.Model, "claude") {
|
||||
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
//}
|
||||
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -421,31 +422,31 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeRealtime:
|
||||
err, usage = OpenaiRealtimeHandler(c, info)
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||
usage = OpenaiTTSHandler(c, resp, info)
|
||||
case relayconstant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case relayconstant.RelayModeAudioTranscription:
|
||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
||||
usage, err = OpenaiHandlerWithUsage(c, info, resp)
|
||||
case relayconstant.RelayModeRerank:
|
||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||
usage, err = common_handler.RerankHandler(c, info, resp)
|
||||
case relayconstant.RelayModeResponses:
|
||||
if info.IsStream {
|
||||
err, usage = OaiResponsesStreamHandler(c, resp, info)
|
||||
usage, err = OaiResponsesStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = OaiResponsesHandler(c, resp, info)
|
||||
usage, err = OaiResponsesHandler(c, info, resp)
|
||||
}
|
||||
default:
|
||||
if info.IsStream {
|
||||
err, usage = OaiStreamHandler(c, resp, info)
|
||||
usage, err = OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = OpenaiHandler(c, resp, info)
|
||||
usage, err = OpenaiHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"one-api/types"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -104,10 +106,10 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
return helper.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
|
||||
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
common.LogError(c, "invalid response or response body")
|
||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
|
||||
}
|
||||
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
@@ -177,26 +179,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var simpleResponse dto.OpenAITextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
}
|
||||
err = common.UnmarshalJson(responseBody, &simpleResponse)
|
||||
err = common.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: *simpleResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
|
||||
}
|
||||
|
||||
forceFormat := false
|
||||
@@ -220,28 +219,28 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
if forceFormat {
|
||||
responseBody, err = common.EncodeJson(simpleResponse)
|
||||
responseBody, err = common.Marshal(simpleResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
||||
claudeRespStr, err := common.EncodeJson(claudeResp)
|
||||
claudeRespStr, err := common.Marshal(claudeResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
return nil, &simpleResponse.Usage
|
||||
return &simpleResponse.Usage, nil
|
||||
}
|
||||
|
||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
|
||||
// the status code has been judged before, if there is a body reading failure,
|
||||
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||
@@ -261,20 +260,20 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
return nil, usage
|
||||
return usage
|
||||
}
|
||||
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// count tokens by audio file duration
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
}
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
@@ -328,9 +327,9 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
||||
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
|
||||
}
|
||||
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
|
||||
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
|
||||
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
info.IsStream = true
|
||||
@@ -368,7 +367,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
||||
}
|
||||
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.UnmarshalJson(message, realtimeEvent)
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
@@ -428,7 +427,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.UnmarshalJson(message, realtimeEvent)
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
@@ -553,18 +552,18 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
|
||||
return err
|
||||
}
|
||||
|
||||
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = common.UnmarshalJson(responseBody, &usageResp)
|
||||
err = common.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
@@ -584,5 +583,5 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
|
||||
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
||||
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
||||
}
|
||||
return nil, &usageResp.Usage
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
|
||||
@@ -9,33 +9,27 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// read response body
|
||||
var responsesResponse dto.OpenAIResponsesResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
}
|
||||
err = common.UnmarshalJson(responseBody, &responsesResponse)
|
||||
err = common.Unmarshal(responseBody, &responsesResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if responsesResponse.Error != nil {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: responsesResponse.Error.Message,
|
||||
Type: "openai_error",
|
||||
Code: responsesResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
@@ -50,13 +44,13 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
for _, tool := range responsesResponse.Tools {
|
||||
info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++
|
||||
}
|
||||
return nil, &usage
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
common.LogError(c, "invalid response or response body")
|
||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
|
||||
}
|
||||
|
||||
var usage = &dto.Usage{}
|
||||
@@ -99,5 +93,5 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
||||
}
|
||||
}
|
||||
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -70,13 +71,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
usage, err = palmHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,14 +2,17 @@ package palm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||
@@ -70,7 +73,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
|
||||
return &response
|
||||
}
|
||||
|
||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) {
|
||||
responseText := ""
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
@@ -121,42 +124,39 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: palmResponse.Error.Message,
|
||||
Type: palmResponse.Error.Status,
|
||||
Param: "",
|
||||
Code: palmResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
Message: palmResponse.Error.Message,
|
||||
Type: palmResponse.Error.Status,
|
||||
Param: "",
|
||||
Code: palmResponse.Error.Code,
|
||||
}, resp.StatusCode)
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
TotalTokens: info.PromptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -73,11 +74,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -76,20 +77,20 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = siliconflowRerankHandler(c, resp)
|
||||
usage, err = siliconflowRerankHandler(c, info, resp)
|
||||
case constant.RelayModeCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,24 +2,26 @@ package siliconflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var siliconflowResp SFRerankResponse
|
||||
err = json.Unmarshal(responseBody, &siliconflowResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
|
||||
@@ -33,10 +35,10 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, usage
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -6,10 +6,11 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -63,7 +64,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||
a.AppID = appId
|
||||
@@ -94,13 +95,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = tencentStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage, err = tencentStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = tencentHandler(c, resp)
|
||||
usage, err = tencentHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,17 +8,20 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.tencent.com/document/product/1729/97732
|
||||
@@ -86,7 +89,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
|
||||
return &response
|
||||
}
|
||||
|
||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
@@ -126,38 +129,35 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return nil, responseText
|
||||
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var tencentSb TencentChatResponseSB
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &tencentSb)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if tencentSb.Response.Error.Code != 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: tencentSb.Response.Error.Message,
|
||||
Code: tencentSb.Response.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
Message: tencentSb.Response.Error.Message,
|
||||
Code: tencentSb.Response.Error.Code,
|
||||
}, resp.StatusCode)
|
||||
}
|
||||
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return &fullTextResponse.Usage, nil
|
||||
}
|
||||
|
||||
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -208,19 +209,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
|
||||
usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
|
||||
usage, err = gemini.GeminiChatStreamHandler(c, info, resp)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
}
|
||||
} else {
|
||||
switch a.RequestMode {
|
||||
@@ -228,12 +229,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
|
||||
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
||||
usage, err = gemini.GeminiChatHandler(c, info, resp)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -4,8 +4,11 @@ import "one-api/common"
|
||||
|
||||
func GetModelRegion(other string, localModelName string) string {
|
||||
// if other is json string
|
||||
if common.IsJsonStr(other) {
|
||||
m := common.StrToMap(other)
|
||||
if common.IsJsonObject(other) {
|
||||
m, err := common.StrToMap(other)
|
||||
if err != nil {
|
||||
return other // return original if parsing fails
|
||||
}
|
||||
if m[localModelName] != nil {
|
||||
return m[localModelName].(string)
|
||||
} else {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -225,18 +226,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
|
||||
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"one-api/relay/constant"
|
||||
@@ -56,6 +57,15 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if strings.HasSuffix(info.UpstreamModelName, "-search") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
|
||||
request.Model = info.UpstreamModelName
|
||||
toMap := request.ToMap()
|
||||
toMap["search_parameters"] = map[string]any{
|
||||
"mode": "on",
|
||||
}
|
||||
return toMap, nil
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "grok-3-mini") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
@@ -95,15 +105,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
|
||||
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
|
||||
default:
|
||||
if info.IsStream {
|
||||
err, usage = xAIStreamHandler(c, resp, info)
|
||||
usage, err = xAIStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = xAIHandler(c, resp, info)
|
||||
usage, err = xAIHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package xai
|
||||
|
||||
var ModelList = []string{
|
||||
// grok-4
|
||||
"grok-4", "grok-4-0709", "grok-4-0709-search",
|
||||
// grok-3
|
||||
"grok-3-beta", "grok-3-mini-beta",
|
||||
// grok-3 mini
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -34,7 +35,7 @@ func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage
|
||||
return openAIResp
|
||||
}
|
||||
|
||||
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
usage := &dto.Usage{}
|
||||
var responseTextBuilder strings.Builder
|
||||
var toolCount int
|
||||
@@ -74,30 +75,28 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
helper.Done(c)
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
var response *dto.SimpleResponse
|
||||
err = common.UnmarshalJson(responseBody, &response)
|
||||
err = common.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return nil, nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
|
||||
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
|
||||
|
||||
// new body
|
||||
encodeJson, err := common.EncodeJson(response)
|
||||
encodeJson, err := common.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return nil, nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, encodeJson)
|
||||
|
||||
return nil, &response.Usage
|
||||
return &response.Usage, nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -74,18 +74,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return dummyResp, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
splits := strings.Split(info.ApiKey, "|")
|
||||
if len(splits) != 3 {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
return nil, types.NewError(errors.New("invalid auth"), types.ErrorCodeChannelInvalidKey)
|
||||
}
|
||||
if a.request == nil {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
|
||||
return nil, types.NewError(errors.New("request is nil"), types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
if info.IsStream {
|
||||
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
usage, err = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
} else {
|
||||
err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
usage, err = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -6,18 +6,18 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// https://console.xfyun.cn/services/cbm
|
||||
@@ -126,11 +126,11 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||
return callUrl
|
||||
}
|
||||
|
||||
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
helper.SetEventStreamHeaders(c)
|
||||
var usage dto.Usage
|
||||
@@ -153,14 +153,14 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
|
||||
return false
|
||||
}
|
||||
})
|
||||
return nil, &usage
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
var usage dto.Usage
|
||||
var content string
|
||||
@@ -191,11 +191,11 @@ func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId s
|
||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -77,11 +78,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = zhipuStreamHandler(c, resp)
|
||||
usage, err = zhipuStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = zhipuHandler(c, resp)
|
||||
usage, err = zhipuHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,18 +3,20 @@ package zhipu
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// https://open.bigmodel.cn/doc/api#chatglm_std
|
||||
@@ -150,7 +152,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
|
||||
return &response, &zhipuResponse.Usage
|
||||
}
|
||||
|
||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage *dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
@@ -211,38 +213,33 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
||||
}
|
||||
})
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var zhipuResponse ZhipuResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if !zhipuResponse.Success {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: zhipuResponse.Msg,
|
||||
Type: "zhipu_error",
|
||||
Param: "",
|
||||
Code: zhipuResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
Message: zhipuResponse.Msg,
|
||||
Code: zhipuResponse.Code,
|
||||
}, resp.StatusCode)
|
||||
}
|
||||
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
return &fullTextResponse.Usage, nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -80,11 +81,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,10 +2,8 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -14,7 +12,10 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
|
||||
@@ -32,14 +33,14 @@ func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoClaude(c)
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateClaudeRequest(c)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
if textRequest.Stream {
|
||||
@@ -48,35 +49,35 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
|
||||
if openaiErr != nil {
|
||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
@@ -109,14 +110,14 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
@@ -124,26 +125,26 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
//log.Printf("usage: %v", usage)
|
||||
if openaiErr != nil {
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
return nil
|
||||
|
||||
@@ -213,7 +213,7 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||
|
||||
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
||||
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
||||
@@ -229,7 +229,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
|
||||
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
ChannelId: channelId,
|
||||
@@ -247,7 +247,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
IsModelMapped: false,
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
|
||||
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package common_handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -9,13 +8,15 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
@@ -24,9 +25,9 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
var jinaResp dto.RerankResponse
|
||||
if info.ChannelType == constant.ChannelTypeXinference {
|
||||
var xinRerankResponse xinference.XinRerankResponse
|
||||
err = common.UnmarshalJson(responseBody, &xinRerankResponse)
|
||||
err = common.Unmarshal(responseBody, &xinRerankResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
|
||||
for i, result := range xinRerankResponse.Results {
|
||||
@@ -59,14 +60,14 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
},
|
||||
}
|
||||
} else {
|
||||
err = common.UnmarshalJson(responseBody, &jinaResp)
|
||||
err = common.Unmarshal(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, jinaResp)
|
||||
return nil, &jinaResp.Usage
|
||||
return &jinaResp.Usage, nil
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -12,6 +11,9 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||
@@ -32,24 +34,24 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
|
||||
return nil
|
||||
}
|
||||
|
||||
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
|
||||
|
||||
var embeddingRequest *dto.EmbeddingRequest
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||
@@ -57,57 +59,57 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
return nil
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -104,11 +105,11 @@ func trimModelThinking(modelName string) string {
|
||||
return modelName
|
||||
}
|
||||
|
||||
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
req, err := getAndValidateGeminiRequest(c)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||
@@ -120,14 +121,14 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
sensitiveWords, err := checkGeminiInputSensitive(req)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
}
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -158,23 +159,23 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
// pre consume quota
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
|
||||
adaptor.Init(relayInfo)
|
||||
@@ -195,7 +196,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
@@ -205,7 +206,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
@@ -215,10 +216,10 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,27 +4,29 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
// 检查是否已经设置过头部
|
||||
if _, exists := c.Get("event_stream_headers_set"); exists {
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
// 设置标志,表示头部已经设置过
|
||||
c.Set("event_stream_headers_set", true)
|
||||
// 检查是否已经设置过头部
|
||||
if _, exists := c.Get("event_stream_headers_set"); exists {
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
// 设置标志,表示头部已经设置过
|
||||
c.Set("event_stream_headers_set", true)
|
||||
}
|
||||
|
||||
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
|
||||
@@ -85,7 +87,7 @@ func ObjectData(c *gin.Context, object interface{}) error {
|
||||
if object == nil {
|
||||
return errors.New("object is nil")
|
||||
}
|
||||
jsonData, err := json.Marshal(object)
|
||||
jsonData, err := common.Marshal(object)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling object: %w", err)
|
||||
}
|
||||
@@ -118,7 +120,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
||||
return ws.WriteMessage(1, jsonData)
|
||||
}
|
||||
|
||||
func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
|
||||
func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
|
||||
errorObj := &dto.RealtimeEvent{
|
||||
Type: "error",
|
||||
EventId: GetLocalRealtimeID(c),
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -107,23 +108,23 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
return imageRequest, nil
|
||||
}
|
||||
|
||||
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoImage(c)
|
||||
|
||||
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
var preConsumedQuota int
|
||||
var quota int
|
||||
@@ -132,13 +133,12 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
// modelRatio 16 = modelPrice $0.04
|
||||
// per 1 modelRatio = $0.04 / 16
|
||||
// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
@@ -169,16 +169,16 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
}
|
||||
if userQuota-quota < 0 {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
|
||||
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
|
||||
}
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
@@ -186,14 +186,14 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
requestBody = convertedRequest.(io.Reader)
|
||||
} else {
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
@@ -206,25 +206,25 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr := service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
if usage.(*dto.Usage).TotalTokens == 0 {
|
||||
|
||||
@@ -575,7 +575,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
common.SysError("get_channel_null: " + err.Error())
|
||||
}
|
||||
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
||||
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
|
||||
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
|
||||
}
|
||||
}
|
||||
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -84,7 +85,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
@@ -92,8 +93,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
@@ -104,13 +104,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||
@@ -122,23 +122,23 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
promptTokens, err = getPromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
}
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newApiErr != nil {
|
||||
return newApiErr
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
if newApiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
@@ -166,7 +166,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
@@ -174,32 +174,29 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
err = json.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = json.Marshal(reqMap)
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,7 +210,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
@@ -222,18 +219,18 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newApiErr = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return newApiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if newApiErr != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return newApiErr
|
||||
}
|
||||
|
||||
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
||||
@@ -281,16 +278,16 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
|
||||
}
|
||||
|
||||
// 预扣费并返回用户剩余配额
|
||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
|
||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
|
||||
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||
}
|
||||
relayInfo.UserQuota = userQuota
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
@@ -314,11 +311,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if preConsumedQuota > 0 {
|
||||
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
|
||||
}
|
||||
}
|
||||
return preConsumedQuota, userQuota, nil
|
||||
|
||||
@@ -2,15 +2,16 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
@@ -22,27 +23,27 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
return token
|
||||
}
|
||||
|
||||
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) {
|
||||
|
||||
var rerankRequest *dto.RerankRequest
|
||||
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
|
||||
|
||||
if rerankRequest.Query == "" {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
if len(rerankRequest.Documents) == 0 {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
promptToken := getRerankPromptToken(*rerankRequest)
|
||||
@@ -50,32 +51,32 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
if common.DebugEnabled {
|
||||
@@ -83,7 +84,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
}
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
@@ -91,18 +92,18 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
return nil
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -46,11 +47,11 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
req, err := getAndValidateResponsesRequest(c)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
|
||||
@@ -59,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
sensitiveWords, err := checkInputSensitive(req, relayInfo)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -78,52 +79,52 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
// pre consume quota
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_request_body_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_error", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "marshal_request_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
err = json.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
}
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = json.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +137,7 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
@@ -145,18 +146,18 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
httpResp = resp.(*http.Response)
|
||||
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user