feat: 修复重试后请求结构混乱,修复rerank端点无法使用
This commit is contained in:
21
common/copy.go
Normal file
21
common/copy.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/antlabs/pcopy"
|
||||
)
|
||||
|
||||
func DeepCopy[T any](src *T) (*T, error) {
|
||||
if src == nil {
|
||||
return nil, fmt.Errorf("copy source cannot be nil")
|
||||
}
|
||||
var dst T
|
||||
err := pcopy.Copy(&dst, src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if &dst == nil {
|
||||
return nil, fmt.Errorf("copy result cannot be nil")
|
||||
}
|
||||
return &dst, nil
|
||||
}
|
||||
@@ -26,6 +26,12 @@ func (r *AudioRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *AudioRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
type AudioResponse struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
@@ -321,8 +321,14 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
return &tokenCountMeta
|
||||
}
|
||||
|
||||
func (claudeRequest *ClaudeRequest) IsStream(c *gin.Context) bool {
|
||||
return claudeRequest.Stream
|
||||
func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
|
||||
return c.Stream
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
c.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
|
||||
|
||||
@@ -48,6 +48,12 @@ func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *EmbeddingRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
func (r *EmbeddingRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return make([]string, 0)
|
||||
|
||||
@@ -73,6 +73,10 @@ func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) SetModelName(modelName string) {
|
||||
// GeminiChatRequest does not have a model field, so this method does nothing.
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
|
||||
var tools []GeminiChatTool
|
||||
if strings.HasSuffix(string(r.Tools), "[") {
|
||||
@@ -312,10 +316,61 @@ type GeminiEmbeddingRequest struct {
|
||||
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||
// Gemini embedding requests are not streamed
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
var inputTexts []string
|
||||
for _, part := range r.Content.Parts {
|
||||
if part.Text != "" {
|
||||
inputTexts = append(inputTexts, part.Text)
|
||||
}
|
||||
}
|
||||
inputText := strings.Join(inputTexts, "\n")
|
||||
return &types.TokenCountMeta{
|
||||
CombineText: inputText,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
type GeminiBatchEmbeddingRequest struct {
|
||||
Requests []*GeminiEmbeddingRequest `json:"requests"`
|
||||
}
|
||||
|
||||
func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||
// Gemini batch embedding requests are not streamed
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
var inputTexts []string
|
||||
for _, request := range r.Requests {
|
||||
meta := request.GetTokenCountMeta()
|
||||
if meta != nil && meta.CombineText != "" {
|
||||
inputTexts = append(inputTexts, meta.CombineText)
|
||||
}
|
||||
}
|
||||
inputText := strings.Join(inputTexts, "\n")
|
||||
return &types.TokenCountMeta{
|
||||
CombineText: inputText,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
for _, req := range r.Requests {
|
||||
req.SetModelName(modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type GeminiEmbeddingResponse struct {
|
||||
Embedding ContentEmbedding `json:"embedding"`
|
||||
}
|
||||
|
||||
@@ -12,10 +12,10 @@ type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N uint `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style json.RawMessage `json:"style,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style json.RawMessage `json:"style,omitempty"`
|
||||
User json.RawMessage `json:"user,omitempty"`
|
||||
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
||||
Background json.RawMessage `json:"background,omitempty"`
|
||||
@@ -63,6 +63,12 @@ func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *ImageRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
i.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Data []ImageData `json:"data"`
|
||||
Created int64 `json:"created"`
|
||||
|
||||
@@ -183,6 +183,12 @@ func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
|
||||
return r.Stream
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
result := make(map[string]any)
|
||||
data, _ := common.Marshal(r)
|
||||
@@ -841,6 +847,12 @@ func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
|
||||
return r.Stream
|
||||
}
|
||||
|
||||
func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
type Reasoning struct {
|
||||
Effort string `json:"effort,omitempty"`
|
||||
Summary string `json:"summary,omitempty"`
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
type Request interface {
|
||||
GetTokenCountMeta() *types.TokenCountMeta
|
||||
IsStream(c *gin.Context) bool
|
||||
SetModelName(modelName string)
|
||||
}
|
||||
|
||||
type BaseRequest struct {
|
||||
@@ -18,7 +19,7 @@ func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
TokenType: types.TokenTypeTokenizer,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BaseRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
func (b *BaseRequest) SetModelName(modelName string) {}
|
||||
|
||||
@@ -37,6 +37,12 @@ func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RerankRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RerankRequest) GetReturnDocuments() bool {
|
||||
if r.ReturnDocuments == nil {
|
||||
return false
|
||||
|
||||
9
go.mod
9
go.mod
@@ -44,7 +44,11 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Masterminds/goutils v1.1.1 // indirect
|
||||
github.com/Masterminds/semver/v3 v3.2.0 // indirect
|
||||
github.com/Masterminds/sprig/v3 v3.2.3 // indirect
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
||||
github.com/antlabs/pcopy v0.1.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
|
||||
@@ -69,6 +73,8 @@ require (
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
github.com/huandu/xstrings v1.3.3 // indirect
|
||||
github.com/imdario/mergo v0.3.11 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.7.1 // indirect
|
||||
@@ -79,11 +85,14 @@ require (
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mitchellh/copystructure v1.0.0 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/mitchellh/reflectwalk v1.0.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/spf13/cast v1.3.1 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
|
||||
45
go.sum
45
go.sum
@@ -1,11 +1,19 @@
|
||||
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
|
||||
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
|
||||
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
|
||||
github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g=
|
||||
github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
|
||||
github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA=
|
||||
github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM=
|
||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
||||
github.com/antlabs/pcopy v0.1.5 h1:5Fa1ExY9T6ar3ysAi4rzB5jiYg72Innm+/ESEIOSHvQ=
|
||||
github.com/antlabs/pcopy v0.1.5/go.mod h1:2FvdkPD3cFiM1CjGuXFCDQZqhKVcLI7IzeSJ2xUIOOI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
|
||||
github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
|
||||
@@ -102,6 +110,7 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
@@ -112,6 +121,10 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
|
||||
github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA=
|
||||
github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -150,8 +163,12 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ=
|
||||
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY=
|
||||
github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -184,14 +201,19 @@ github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
|
||||
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@@ -229,25 +251,36 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
|
||||
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
|
||||
golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
|
||||
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
|
||||
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -256,18 +289,29 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
@@ -282,6 +326,7 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
@@ -16,12 +17,17 @@ import (
|
||||
func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
audioRequest, ok := info.Request.(*dto.AudioRequest)
|
||||
audioReq, ok := info.Request.(*dto.AudioRequest)
|
||||
if !ok {
|
||||
return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, audioRequest)
|
||||
request, err := common.DeepCopy(audioReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -32,7 +38,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest)
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -21,13 +21,18 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
textRequest, ok := info.Request.(*dto.ClaudeRequest)
|
||||
claudeReq, ok := info.Request.(*dto.ClaudeRequest)
|
||||
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, textRequest)
|
||||
request, err := common.DeepCopy(claudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -38,30 +43,30 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
if textRequest.MaxTokens == 0 {
|
||||
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
if request.MaxTokens == 0 {
|
||||
request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
|
||||
}
|
||||
|
||||
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
if textRequest.Thinking == nil {
|
||||
strings.HasSuffix(request.Model, "-thinking") {
|
||||
if request.Thinking == nil {
|
||||
// 因为BudgetTokens 必须大于1024
|
||||
if textRequest.MaxTokens < 1280 {
|
||||
textRequest.MaxTokens = 1280
|
||||
if request.MaxTokens < 1280 {
|
||||
request.MaxTokens = 1280
|
||||
}
|
||||
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
textRequest.Thinking = &dto.Thinking{
|
||||
request.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
BudgetTokens: common.GetPointer[int](int(float64(request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
textRequest.TopP = 0
|
||||
textRequest.Temperature = common.GetPointer[float64](1.0)
|
||||
request.TopP = 0
|
||||
request.Temperature = common.GetPointer[float64](1.0)
|
||||
}
|
||||
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
info.UpstreamModelName = textRequest.Model
|
||||
request.Model = strings.TrimSuffix(request.Model, "-thinking")
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
@@ -72,7 +77,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest)
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -158,7 +158,14 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
||||
if streamSupportedChannels[channelMeta.ChannelType] {
|
||||
channelMeta.SupportStreamOptions = true
|
||||
}
|
||||
|
||||
info.ChannelMeta = channelMeta
|
||||
|
||||
// reset some fields based on channel meta
|
||||
// 重置某些字段,例如模型名称等
|
||||
if info.Request != nil {
|
||||
info.Request.SetModelName(info.OriginModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func (info *RelayInfo) ToString() string {
|
||||
|
||||
@@ -16,15 +16,19 @@ import (
|
||||
)
|
||||
|
||||
func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
|
||||
embeddingReq, ok := info.Request.(*dto.EmbeddingRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, embeddingRequest)
|
||||
request, err := common.DeepCopy(embeddingReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -35,7 +39,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest)
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -53,13 +53,18 @@ func trimModelThinking(modelName string) string {
|
||||
func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
request, ok := info.Request.(*dto.GeminiChatRequest)
|
||||
geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
request, err := common.DeepCopy(geminiReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err := helper.ModelMappedHelper(c, info, request)
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -170,7 +175,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
|
||||
info.IsGeminiBatchEmbedding = isBatch
|
||||
|
||||
var req any
|
||||
var req dto.Request
|
||||
var err error
|
||||
var inputTexts []string
|
||||
|
||||
|
||||
@@ -4,15 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/dto"
|
||||
common2 "one-api/logger"
|
||||
"one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/dto"
|
||||
"one-api/relay/common"
|
||||
)
|
||||
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error {
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
@@ -54,40 +51,7 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro
|
||||
}
|
||||
}
|
||||
if request != nil {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatGemini:
|
||||
// Gemini 模型映射
|
||||
case types.RelayFormatClaude:
|
||||
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
|
||||
claudeRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatOpenAIResponses:
|
||||
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||
openAIResponsesRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatOpenAIAudio:
|
||||
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
|
||||
openAIAudioRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatOpenAIImage:
|
||||
if imageRequest, ok := request.(*dto.ImageRequest); ok {
|
||||
imageRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatRerank:
|
||||
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
|
||||
rerankRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatEmbedding:
|
||||
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
|
||||
embeddingRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
default:
|
||||
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
||||
openAIRequest.Model = info.UpstreamModelName
|
||||
} else {
|
||||
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
|
||||
}
|
||||
}
|
||||
request.SetModelName(info.UpstreamModelName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,16 +20,19 @@ import (
|
||||
)
|
||||
|
||||
func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
imageRequest, ok := info.Request.(*dto.ImageRequest)
|
||||
|
||||
imageReq, ok := info.Request.(*dto.ImageRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, imageRequest)
|
||||
request, err := common.DeepCopy(imageReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -49,7 +52,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest)
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
@@ -102,21 +105,21 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
}
|
||||
|
||||
if usage.(*dto.Usage).TotalTokens == 0 {
|
||||
usage.(*dto.Usage).TotalTokens = int(imageRequest.N)
|
||||
usage.(*dto.Usage).TotalTokens = int(request.N)
|
||||
}
|
||||
if usage.(*dto.Usage).PromptTokens == 0 {
|
||||
usage.(*dto.Usage).PromptTokens = int(imageRequest.N)
|
||||
usage.(*dto.Usage).PromptTokens = int(request.N)
|
||||
}
|
||||
|
||||
quality := "standard"
|
||||
if imageRequest.Quality == "hd" {
|
||||
if request.Quality == "hd" {
|
||||
quality = "hd"
|
||||
}
|
||||
|
||||
var logContent string
|
||||
|
||||
if len(imageRequest.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
||||
if len(request.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
|
||||
@@ -25,38 +25,41 @@ import (
|
||||
)
|
||||
|
||||
func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest)
|
||||
|
||||
textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
|
||||
if !ok {
|
||||
//return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request)
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
request, err := common.DeepCopy(textReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, textRequest)
|
||||
if request.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
includeUsage := true
|
||||
// 判断用户是否需要返回使用情况
|
||||
if textRequest.StreamOptions != nil {
|
||||
includeUsage = textRequest.StreamOptions.IncludeUsage
|
||||
if request.StreamOptions != nil {
|
||||
includeUsage = request.StreamOptions.IncludeUsage
|
||||
}
|
||||
|
||||
// 如果不支持StreamOptions,将StreamOptions设置为nil
|
||||
if !info.SupportStreamOptions || !textRequest.Stream {
|
||||
textRequest.StreamOptions = nil
|
||||
if !info.SupportStreamOptions || !request.Stream {
|
||||
request.StreamOptions = nil
|
||||
} else {
|
||||
// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
|
||||
if constant.ForceStreamOption {
|
||||
textRequest.StreamOptions = &dto.StreamOptions{
|
||||
request.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
}
|
||||
@@ -81,7 +84,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest)
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -16,23 +16,20 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||
for _, document := range rerankRequest.Documents {
|
||||
tkm := service.CountTokenInput(document, rerankRequest.Model)
|
||||
token += tkm
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
rerankRequest, ok := info.Request.(*dto.RerankRequest)
|
||||
rerankReq, ok := info.Request.(*dto.RerankRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, rerankRequest)
|
||||
request, err := common.DeepCopy(rerankReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -51,7 +48,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest)
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user