Merge branch 'alpha' into mutil_key_channel
# Conflicts: # controller/channel.go # docker-compose.yml # web/src/components/table/ChannelsTable.js # web/src/pages/Channel/EditChannel.js
This commit is contained in:
12
.env.example
12
.env.example
@@ -7,6 +7,8 @@
|
|||||||
# 调试相关配置
|
# 调试相关配置
|
||||||
# 启用pprof
|
# 启用pprof
|
||||||
# ENABLE_PPROF=true
|
# ENABLE_PPROF=true
|
||||||
|
# 启用调试模式
|
||||||
|
# DEBUG=true
|
||||||
|
|
||||||
# 数据库相关配置
|
# 数据库相关配置
|
||||||
# 数据库连接字符串
|
# 数据库连接字符串
|
||||||
@@ -41,6 +43,14 @@
|
|||||||
# 更新任务启用
|
# 更新任务启用
|
||||||
# UPDATE_TASK=true
|
# UPDATE_TASK=true
|
||||||
|
|
||||||
|
# 对话超时设置
|
||||||
|
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||||
|
# RELAY_TIMEOUT=0
|
||||||
|
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||||
|
# STREAMING_TIMEOUT=120
|
||||||
|
|
||||||
|
# Gemini 识别图片 最大图片数量
|
||||||
|
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
||||||
|
|
||||||
# 会话密钥
|
# 会话密钥
|
||||||
# SESSION_SECRET=random_string
|
# SESSION_SECRET=random_string
|
||||||
@@ -58,8 +68,6 @@
|
|||||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||||
# DIFY_DEBUG=true
|
# DIFY_DEBUG=true
|
||||||
# 设置流式一次回复的超时时间
|
|
||||||
# STREAMING_TIMEOUT=90
|
|
||||||
|
|
||||||
|
|
||||||
# 节点类型
|
# 节点类型
|
||||||
|
|||||||
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
### PR 类型
|
||||||
|
|
||||||
|
- [ ] Bug 修复
|
||||||
|
- [ ] 新功能
|
||||||
|
- [ ] 文档更新
|
||||||
|
- [ ] 其他
|
||||||
|
|
||||||
|
### PR 是否包含破坏性更新?
|
||||||
|
|
||||||
|
- [ ] 是
|
||||||
|
- [ ] 否
|
||||||
|
|
||||||
|
### PR 描述
|
||||||
|
|
||||||
|
**请在下方详细描述您的 PR,包括目的、实现细节等。**
|
||||||
|
|
||||||
|
### **重要提示**
|
||||||
|
|
||||||
|
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**
|
||||||
1
.github/workflows/macos-release.yml
vendored
1
.github/workflows/macos-release.yml
vendored
@@ -26,6 +26,7 @@ jobs:
|
|||||||
- name: Build Frontend
|
- name: Build Frontend
|
||||||
env:
|
env:
|
||||||
CI: ""
|
CI: ""
|
||||||
|
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||||
run: |
|
run: |
|
||||||
cd web
|
cd web
|
||||||
bun install
|
bun install
|
||||||
|
|||||||
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
name: Check PR Branching Strategy
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened, edited]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-branching-strategy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Enforce branching strategy
|
||||||
|
run: |
|
||||||
|
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
||||||
|
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
||||||
|
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
||||||
|
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Branching strategy check passed."
|
||||||
@@ -24,8 +24,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
|
|||||||
|
|
||||||
FROM alpine
|
FROM alpine
|
||||||
|
|
||||||
RUN apk update \
|
RUN apk upgrade --no-cache \
|
||||||
&& apk upgrade \
|
|
||||||
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
||||||
&& update-ca-certificates
|
&& update-ca-certificates
|
||||||
|
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
|
|||||||
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
||||||
|
|
||||||
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
||||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
|
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 120 seconds
|
||||||
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
||||||
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
||||||
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
||||||
|
|||||||
@@ -27,9 +27,6 @@
|
|||||||
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
|
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
|
||||||
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
|
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
|
||||||
</a>
|
</a>
|
||||||
<a href="https://coderabbit.ai">
|
|
||||||
<img src="https://img.shields.io/coderabbit/prs/github/QuantumNous/new-api?utm_source=oss&utm_medium=github&utm_campaign=QuantumNous%2Fnew-api&labelColor=171717&color=FF570A&link=https%3A%2F%2Fcoderabbit.ai&label=CodeRabbit+Reviews" alt="CodeRabbit Pull Request Reviews">
|
|
||||||
</a>
|
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -103,7 +100,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
|||||||
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
||||||
|
|
||||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
||||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
|
- `STREAMING_TIMEOUT`:流式回复超时时间,默认120秒
|
||||||
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
||||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
||||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
||||||
|
|||||||
71
common/api_type.go
Normal file
71
common/api_type.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "one-api/constant"
|
||||||
|
|
||||||
|
func ChannelType2APIType(channelType int) (int, bool) {
|
||||||
|
apiType := -1
|
||||||
|
switch channelType {
|
||||||
|
case constant.ChannelTypeOpenAI:
|
||||||
|
apiType = constant.APITypeOpenAI
|
||||||
|
case constant.ChannelTypeAnthropic:
|
||||||
|
apiType = constant.APITypeAnthropic
|
||||||
|
case constant.ChannelTypeBaidu:
|
||||||
|
apiType = constant.APITypeBaidu
|
||||||
|
case constant.ChannelTypePaLM:
|
||||||
|
apiType = constant.APITypePaLM
|
||||||
|
case constant.ChannelTypeZhipu:
|
||||||
|
apiType = constant.APITypeZhipu
|
||||||
|
case constant.ChannelTypeAli:
|
||||||
|
apiType = constant.APITypeAli
|
||||||
|
case constant.ChannelTypeXunfei:
|
||||||
|
apiType = constant.APITypeXunfei
|
||||||
|
case constant.ChannelTypeAIProxyLibrary:
|
||||||
|
apiType = constant.APITypeAIProxyLibrary
|
||||||
|
case constant.ChannelTypeTencent:
|
||||||
|
apiType = constant.APITypeTencent
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
apiType = constant.APITypeGemini
|
||||||
|
case constant.ChannelTypeZhipu_v4:
|
||||||
|
apiType = constant.APITypeZhipuV4
|
||||||
|
case constant.ChannelTypeOllama:
|
||||||
|
apiType = constant.APITypeOllama
|
||||||
|
case constant.ChannelTypePerplexity:
|
||||||
|
apiType = constant.APITypePerplexity
|
||||||
|
case constant.ChannelTypeAws:
|
||||||
|
apiType = constant.APITypeAws
|
||||||
|
case constant.ChannelTypeCohere:
|
||||||
|
apiType = constant.APITypeCohere
|
||||||
|
case constant.ChannelTypeDify:
|
||||||
|
apiType = constant.APITypeDify
|
||||||
|
case constant.ChannelTypeJina:
|
||||||
|
apiType = constant.APITypeJina
|
||||||
|
case constant.ChannelCloudflare:
|
||||||
|
apiType = constant.APITypeCloudflare
|
||||||
|
case constant.ChannelTypeSiliconFlow:
|
||||||
|
apiType = constant.APITypeSiliconFlow
|
||||||
|
case constant.ChannelTypeVertexAi:
|
||||||
|
apiType = constant.APITypeVertexAi
|
||||||
|
case constant.ChannelTypeMistral:
|
||||||
|
apiType = constant.APITypeMistral
|
||||||
|
case constant.ChannelTypeDeepSeek:
|
||||||
|
apiType = constant.APITypeDeepSeek
|
||||||
|
case constant.ChannelTypeMokaAI:
|
||||||
|
apiType = constant.APITypeMokaAI
|
||||||
|
case constant.ChannelTypeVolcEngine:
|
||||||
|
apiType = constant.APITypeVolcEngine
|
||||||
|
case constant.ChannelTypeBaiduV2:
|
||||||
|
apiType = constant.APITypeBaiduV2
|
||||||
|
case constant.ChannelTypeOpenRouter:
|
||||||
|
apiType = constant.APITypeOpenRouter
|
||||||
|
case constant.ChannelTypeXinference:
|
||||||
|
apiType = constant.APITypeXinference
|
||||||
|
case constant.ChannelTypeXai:
|
||||||
|
apiType = constant.APITypeXai
|
||||||
|
case constant.ChannelTypeCoze:
|
||||||
|
apiType = constant.APITypeCoze
|
||||||
|
}
|
||||||
|
if apiType == -1 {
|
||||||
|
return constant.APITypeOpenAI, false
|
||||||
|
}
|
||||||
|
return apiType, true
|
||||||
|
}
|
||||||
@@ -193,107 +193,3 @@ const (
|
|||||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||||
ChannelStatusAutoDisabled = 3
|
ChannelStatusAutoDisabled = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
ChannelTypeUnknown = 0
|
|
||||||
ChannelTypeOpenAI = 1
|
|
||||||
ChannelTypeMidjourney = 2
|
|
||||||
ChannelTypeAzure = 3
|
|
||||||
ChannelTypeOllama = 4
|
|
||||||
ChannelTypeMidjourneyPlus = 5
|
|
||||||
ChannelTypeOpenAIMax = 6
|
|
||||||
ChannelTypeOhMyGPT = 7
|
|
||||||
ChannelTypeCustom = 8
|
|
||||||
ChannelTypeAILS = 9
|
|
||||||
ChannelTypeAIProxy = 10
|
|
||||||
ChannelTypePaLM = 11
|
|
||||||
ChannelTypeAPI2GPT = 12
|
|
||||||
ChannelTypeAIGC2D = 13
|
|
||||||
ChannelTypeAnthropic = 14
|
|
||||||
ChannelTypeBaidu = 15
|
|
||||||
ChannelTypeZhipu = 16
|
|
||||||
ChannelTypeAli = 17
|
|
||||||
ChannelTypeXunfei = 18
|
|
||||||
ChannelType360 = 19
|
|
||||||
ChannelTypeOpenRouter = 20
|
|
||||||
ChannelTypeAIProxyLibrary = 21
|
|
||||||
ChannelTypeFastGPT = 22
|
|
||||||
ChannelTypeTencent = 23
|
|
||||||
ChannelTypeGemini = 24
|
|
||||||
ChannelTypeMoonshot = 25
|
|
||||||
ChannelTypeZhipu_v4 = 26
|
|
||||||
ChannelTypePerplexity = 27
|
|
||||||
ChannelTypeLingYiWanWu = 31
|
|
||||||
ChannelTypeAws = 33
|
|
||||||
ChannelTypeCohere = 34
|
|
||||||
ChannelTypeMiniMax = 35
|
|
||||||
ChannelTypeSunoAPI = 36
|
|
||||||
ChannelTypeDify = 37
|
|
||||||
ChannelTypeJina = 38
|
|
||||||
ChannelCloudflare = 39
|
|
||||||
ChannelTypeSiliconFlow = 40
|
|
||||||
ChannelTypeVertexAi = 41
|
|
||||||
ChannelTypeMistral = 42
|
|
||||||
ChannelTypeDeepSeek = 43
|
|
||||||
ChannelTypeMokaAI = 44
|
|
||||||
ChannelTypeVolcEngine = 45
|
|
||||||
ChannelTypeBaiduV2 = 46
|
|
||||||
ChannelTypeXinference = 47
|
|
||||||
ChannelTypeXai = 48
|
|
||||||
ChannelTypeCoze = 49
|
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
|
||||||
"", // 0
|
|
||||||
"https://api.openai.com", // 1
|
|
||||||
"https://oa.api2d.net", // 2
|
|
||||||
"", // 3
|
|
||||||
"http://localhost:11434", // 4
|
|
||||||
"https://api.openai-sb.com", // 5
|
|
||||||
"https://api.openaimax.com", // 6
|
|
||||||
"https://api.ohmygpt.com", // 7
|
|
||||||
"", // 8
|
|
||||||
"https://api.caipacity.com", // 9
|
|
||||||
"https://api.aiproxy.io", // 10
|
|
||||||
"", // 11
|
|
||||||
"https://api.api2gpt.com", // 12
|
|
||||||
"https://api.aigc2d.com", // 13
|
|
||||||
"https://api.anthropic.com", // 14
|
|
||||||
"https://aip.baidubce.com", // 15
|
|
||||||
"https://open.bigmodel.cn", // 16
|
|
||||||
"https://dashscope.aliyuncs.com", // 17
|
|
||||||
"", // 18
|
|
||||||
"https://api.360.cn", // 19
|
|
||||||
"https://openrouter.ai/api", // 20
|
|
||||||
"https://api.aiproxy.io", // 21
|
|
||||||
"https://fastgpt.run/api/openapi", // 22
|
|
||||||
"https://hunyuan.tencentcloudapi.com", //23
|
|
||||||
"https://generativelanguage.googleapis.com", //24
|
|
||||||
"https://api.moonshot.cn", //25
|
|
||||||
"https://open.bigmodel.cn", //26
|
|
||||||
"https://api.perplexity.ai", //27
|
|
||||||
"", //28
|
|
||||||
"", //29
|
|
||||||
"", //30
|
|
||||||
"https://api.lingyiwanwu.com", //31
|
|
||||||
"", //32
|
|
||||||
"", //33
|
|
||||||
"https://api.cohere.ai", //34
|
|
||||||
"https://api.minimax.chat", //35
|
|
||||||
"", //36
|
|
||||||
"https://api.dify.ai", //37
|
|
||||||
"https://api.jina.ai", //38
|
|
||||||
"https://api.cloudflare.com", //39
|
|
||||||
"https://api.siliconflow.cn", //40
|
|
||||||
"", //41
|
|
||||||
"https://api.mistral.ai", //42
|
|
||||||
"https://api.deepseek.com", //43
|
|
||||||
"https://api.moka.ai", //44
|
|
||||||
"https://ark.cn-beijing.volces.com", //45
|
|
||||||
"https://qianfan.baidubce.com", //46
|
|
||||||
"", //47
|
|
||||||
"https://api.x.ai", //48
|
|
||||||
"https://api.coze.cn", //49
|
|
||||||
}
|
|
||||||
|
|||||||
41
common/endpoint_type.go
Normal file
41
common/endpoint_type.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "one-api/constant"
|
||||||
|
|
||||||
|
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
|
||||||
|
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
|
||||||
|
var endpointTypes []constant.EndpointType
|
||||||
|
switch channelType {
|
||||||
|
case constant.ChannelTypeJina:
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
|
||||||
|
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
|
||||||
|
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
|
||||||
|
//case constant.ChannelTypeSunoAPI:
|
||||||
|
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
|
||||||
|
//case constant.ChannelTypeKling:
|
||||||
|
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
|
||||||
|
//case constant.ChannelTypeJimeng:
|
||||||
|
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
|
||||||
|
case constant.ChannelTypeAws:
|
||||||
|
fallthrough
|
||||||
|
case constant.ChannelTypeAnthropic:
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
|
||||||
|
case constant.ChannelTypeVertexAi:
|
||||||
|
fallthrough
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||||
|
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||||
|
default:
|
||||||
|
if IsOpenAIResponseOnlyModel(modelName) {
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||||
|
} else {
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if IsImageGenerationModel(modelName) {
|
||||||
|
// add to first
|
||||||
|
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
|
||||||
|
}
|
||||||
|
return endpointTypes
|
||||||
|
}
|
||||||
@@ -2,10 +2,11 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
|
"one-api/constant"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const KeyRequestBody = "key_request_body"
|
const KeyRequestBody = "key_request_body"
|
||||||
@@ -31,7 +32,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
}
|
}
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
err = json.Unmarshal(requestBody, &v)
|
err = UnmarshalJson(requestBody, &v)
|
||||||
} else {
|
} else {
|
||||||
// skip for now
|
// skip for now
|
||||||
// TODO: someday non json request have variant model, we will need to implementation this
|
// TODO: someday non json request have variant model, we will need to implementation this
|
||||||
@@ -43,3 +44,35 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
||||||
|
c.Set(string(key), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
||||||
|
return c.Get(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
||||||
|
return c.GetString(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
||||||
|
return c.GetInt(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
||||||
|
return c.GetBool(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
||||||
|
return c.GetStringSlice(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
||||||
|
return c.GetStringMap(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
||||||
|
return c.GetTime(string(key))
|
||||||
|
}
|
||||||
|
|||||||
57
common/http.go
Normal file
57
common/http.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CloseResponseBodyGracefully(httpResponse *http.Response) {
|
||||||
|
if httpResponse == nil || httpResponse.Body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := httpResponse.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
SysError("failed to close response body: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
|
||||||
|
if c.Writer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body := io.NopCloser(bytes.NewBuffer(data))
|
||||||
|
|
||||||
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||||
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||||
|
// So the httpClient will be confused by the response.
|
||||||
|
// For example, Postman will report error, and we cannot check the response at all.
|
||||||
|
if src != nil {
|
||||||
|
for k, v := range src.Header {
|
||||||
|
// avoid setting Content-Length
|
||||||
|
if k == "Content-Length" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set Content-Length header manually BEFORE calling WriteHeader
|
||||||
|
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||||
|
|
||||||
|
// Write header with status code (this sends the headers)
|
||||||
|
if src != nil {
|
||||||
|
c.Writer.WriteHeader(src.StatusCode)
|
||||||
|
} else {
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := io.Copy(c.Writer, body)
|
||||||
|
if err != nil {
|
||||||
|
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"one-api/constant"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -24,7 +25,7 @@ func printHelp() {
|
|||||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadEnv() {
|
func InitEnv() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *PrintVersion {
|
if *PrintVersion {
|
||||||
@@ -95,4 +96,25 @@ func LoadEnv() {
|
|||||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||||
|
|
||||||
|
initConstantEnv()
|
||||||
|
}
|
||||||
|
|
||||||
|
func initConstantEnv() {
|
||||||
|
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||||
|
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||||
|
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||||
|
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||||
|
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||||
|
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||||
|
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||||
|
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||||
|
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||||
|
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||||
|
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||||
|
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||||
|
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||||
|
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||||
|
// 是否启用错误日志
|
||||||
|
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,12 +5,16 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DecodeJson(data []byte, v any) error {
|
func UnmarshalJson(data []byte, v any) error {
|
||||||
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
|
return json.Unmarshal(data, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecodeJsonStr(data string, v any) error {
|
func UnmarshalJsonStr(data string, v any) error {
|
||||||
return DecodeJson(StringToByteSlice(data), v)
|
return json.Unmarshal(StringToByteSlice(data), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeJson(reader *bytes.Reader, v any) error {
|
||||||
|
return json.NewDecoder(reader).Decode(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func EncodeJson(v any) ([]byte, error) {
|
func EncodeJson(v any) ([]byte, error) {
|
||||||
|
|||||||
42
common/model.go
Normal file
42
common/model.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
|
||||||
|
OpenAIResponseOnlyModels = []string{
|
||||||
|
"o3-pro",
|
||||||
|
"o3-deep-research",
|
||||||
|
"o4-mini-deep-research",
|
||||||
|
}
|
||||||
|
ImageGenerationModels = []string{
|
||||||
|
"dall-e-3",
|
||||||
|
"dall-e-2",
|
||||||
|
"gpt-image-1",
|
||||||
|
"prefix:imagen-",
|
||||||
|
"flux-",
|
||||||
|
"flux.1-",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||||
|
for _, m := range OpenAIResponseOnlyModels {
|
||||||
|
if strings.Contains(modelName, m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsImageGenerationModel(modelName string) bool {
|
||||||
|
modelName = strings.ToLower(modelName)
|
||||||
|
for _, m := range ImageGenerationModels {
|
||||||
|
if strings.Contains(modelName, m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
62
common/page_info.go
Normal file
62
common/page_info.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PageInfo struct {
|
||||||
|
Page int `json:"page"` // page num 页码
|
||||||
|
PageSize int `json:"page_size"` // page size 页大小
|
||||||
|
StartTimestamp int64 `json:"start_timestamp"` // 秒级
|
||||||
|
EndTimestamp int64 `json:"end_timestamp"` // 秒级
|
||||||
|
|
||||||
|
Total int `json:"total"` // 总条数,后设置
|
||||||
|
Items any `json:"items"` // 数据,后设置
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) GetStartIdx() int {
|
||||||
|
return (p.Page - 1) * p.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) GetEndIdx() int {
|
||||||
|
return p.Page * p.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) GetPageSize() int {
|
||||||
|
return p.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) GetPage() int {
|
||||||
|
return p.Page
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) SetTotal(total int) {
|
||||||
|
p.Total = total
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) SetItems(items any) {
|
||||||
|
p.Items = items
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPageQuery(c *gin.Context) (*PageInfo, error) {
|
||||||
|
pageInfo := &PageInfo{}
|
||||||
|
err := c.BindQuery(pageInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if pageInfo.Page < 1 {
|
||||||
|
// 兼容
|
||||||
|
page, _ := strconv.Atoi(c.Query("p"))
|
||||||
|
if page != 0 {
|
||||||
|
pageInfo.Page = page
|
||||||
|
} else {
|
||||||
|
pageInfo.Page = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pageInfo.PageSize == 0 {
|
||||||
|
pageInfo.PageSize = ItemsPerPage
|
||||||
|
}
|
||||||
|
return pageInfo, nil
|
||||||
|
}
|
||||||
@@ -16,6 +16,10 @@ import (
|
|||||||
var RDB *redis.Client
|
var RDB *redis.Client
|
||||||
var RedisEnabled = true
|
var RedisEnabled = true
|
||||||
|
|
||||||
|
func RedisKeyCacheSeconds() int {
|
||||||
|
return SyncFrequency
|
||||||
|
}
|
||||||
|
|
||||||
// InitRedisClient This function is called after init()
|
// InitRedisClient This function is called after init()
|
||||||
func InitRedisClient() (err error) {
|
func InitRedisClient() (err error) {
|
||||||
if os.Getenv("REDIS_CONN_STRING") == "" {
|
if os.Getenv("REDIS_CONN_STRING") == "" {
|
||||||
@@ -141,7 +145,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
|||||||
|
|
||||||
txn := RDB.TxPipeline()
|
txn := RDB.TxPipeline()
|
||||||
txn.HSet(ctx, key, data)
|
txn.HSet(ctx, key, data)
|
||||||
|
|
||||||
|
// 只有在 expiration 大于 0 时才设置过期时间
|
||||||
|
if expiration > 0 {
|
||||||
txn.Expire(ctx, key, expiration)
|
txn.Expire(ctx, key, expiration)
|
||||||
|
}
|
||||||
|
|
||||||
_, err := txn.Exec(ctx)
|
_, err := txn.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAudioDuration returns the duration of an audio file in seconds.
|
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||||
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
|
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
|
||||||
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||||
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||||
output, err := c.Output()
|
output, err := c.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||||
}
|
}
|
||||||
|
durationStr := string(bytes.TrimSpace(output))
|
||||||
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
|
if durationStr == "N/A" {
|
||||||
|
// Create a temporary output file name
|
||||||
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||||
|
}
|
||||||
|
tmpName := tmpFp.Name()
|
||||||
|
// Close immediately so ffmpeg can open the file on Windows.
|
||||||
|
_ = tmpFp.Close()
|
||||||
|
defer os.Remove(tmpName)
|
||||||
|
|
||||||
|
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||||
|
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||||
|
if err := ffmpegCmd.Run(); err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recalculate the duration of the new file
|
||||||
|
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||||
|
output, err := c.Output()
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||||
|
}
|
||||||
|
durationStr = string(bytes.TrimSpace(output))
|
||||||
|
}
|
||||||
|
return strconv.ParseFloat(durationStr, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildURL concatenates base and endpoint, returns the complete url string
|
||||||
|
func BuildURL(base string, endpoint string) string {
|
||||||
|
u, err := url.Parse(base)
|
||||||
|
if err != nil {
|
||||||
|
return base + endpoint
|
||||||
|
}
|
||||||
|
end := endpoint
|
||||||
|
if end == "" {
|
||||||
|
end = "/"
|
||||||
|
}
|
||||||
|
ref, err := url.Parse(end)
|
||||||
|
if err != nil {
|
||||||
|
return base + endpoint
|
||||||
|
}
|
||||||
|
return u.ResolveReference(ref).String()
|
||||||
}
|
}
|
||||||
|
|||||||
26
constant/README.md
Normal file
26
constant/README.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# constant 包 (`/constant`)
|
||||||
|
|
||||||
|
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
|
||||||
|
|
||||||
|
## 当前文件
|
||||||
|
|
||||||
|
| 文件 | 说明 |
|
||||||
|
|----------------------|---------------------------------------------------------------------|
|
||||||
|
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
|
||||||
|
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
|
||||||
|
| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 |
|
||||||
|
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
|
||||||
|
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
|
||||||
|
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
|
||||||
|
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
|
||||||
|
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
|
||||||
|
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
|
||||||
|
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
|
||||||
|
|
||||||
|
## 使用约定
|
||||||
|
|
||||||
|
1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
|
||||||
|
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
|
||||||
|
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
|
||||||
|
|
||||||
|
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
|
||||||
34
constant/api_type.go
Normal file
34
constant/api_type.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
const (
|
||||||
|
APITypeOpenAI = iota
|
||||||
|
APITypeAnthropic
|
||||||
|
APITypePaLM
|
||||||
|
APITypeBaidu
|
||||||
|
APITypeZhipu
|
||||||
|
APITypeAli
|
||||||
|
APITypeXunfei
|
||||||
|
APITypeAIProxyLibrary
|
||||||
|
APITypeTencent
|
||||||
|
APITypeGemini
|
||||||
|
APITypeZhipuV4
|
||||||
|
APITypeOllama
|
||||||
|
APITypePerplexity
|
||||||
|
APITypeAws
|
||||||
|
APITypeCohere
|
||||||
|
APITypeDify
|
||||||
|
APITypeJina
|
||||||
|
APITypeCloudflare
|
||||||
|
APITypeSiliconFlow
|
||||||
|
APITypeVertexAi
|
||||||
|
APITypeMistral
|
||||||
|
APITypeDeepSeek
|
||||||
|
APITypeMokaAI
|
||||||
|
APITypeVolcEngine
|
||||||
|
APITypeBaiduV2
|
||||||
|
APITypeOpenRouter
|
||||||
|
APITypeXinference
|
||||||
|
APITypeXai
|
||||||
|
APITypeCoze
|
||||||
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
)
|
||||||
@@ -1,14 +1,5 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
import "one-api/common"
|
|
||||||
|
|
||||||
var (
|
|
||||||
TokenCacheSeconds = common.SyncFrequency
|
|
||||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
|
||||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
|
||||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
|
||||||
)
|
|
||||||
|
|
||||||
// Cache keys
|
// Cache keys
|
||||||
const (
|
const (
|
||||||
UserGroupKeyFmt = "user_group:%d"
|
UserGroupKeyFmt = "user_group:%d"
|
||||||
|
|||||||
109
constant/channel.go
Normal file
109
constant/channel.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChannelTypeUnknown = 0
|
||||||
|
ChannelTypeOpenAI = 1
|
||||||
|
ChannelTypeMidjourney = 2
|
||||||
|
ChannelTypeAzure = 3
|
||||||
|
ChannelTypeOllama = 4
|
||||||
|
ChannelTypeMidjourneyPlus = 5
|
||||||
|
ChannelTypeOpenAIMax = 6
|
||||||
|
ChannelTypeOhMyGPT = 7
|
||||||
|
ChannelTypeCustom = 8
|
||||||
|
ChannelTypeAILS = 9
|
||||||
|
ChannelTypeAIProxy = 10
|
||||||
|
ChannelTypePaLM = 11
|
||||||
|
ChannelTypeAPI2GPT = 12
|
||||||
|
ChannelTypeAIGC2D = 13
|
||||||
|
ChannelTypeAnthropic = 14
|
||||||
|
ChannelTypeBaidu = 15
|
||||||
|
ChannelTypeZhipu = 16
|
||||||
|
ChannelTypeAli = 17
|
||||||
|
ChannelTypeXunfei = 18
|
||||||
|
ChannelType360 = 19
|
||||||
|
ChannelTypeOpenRouter = 20
|
||||||
|
ChannelTypeAIProxyLibrary = 21
|
||||||
|
ChannelTypeFastGPT = 22
|
||||||
|
ChannelTypeTencent = 23
|
||||||
|
ChannelTypeGemini = 24
|
||||||
|
ChannelTypeMoonshot = 25
|
||||||
|
ChannelTypeZhipu_v4 = 26
|
||||||
|
ChannelTypePerplexity = 27
|
||||||
|
ChannelTypeLingYiWanWu = 31
|
||||||
|
ChannelTypeAws = 33
|
||||||
|
ChannelTypeCohere = 34
|
||||||
|
ChannelTypeMiniMax = 35
|
||||||
|
ChannelTypeSunoAPI = 36
|
||||||
|
ChannelTypeDify = 37
|
||||||
|
ChannelTypeJina = 38
|
||||||
|
ChannelCloudflare = 39
|
||||||
|
ChannelTypeSiliconFlow = 40
|
||||||
|
ChannelTypeVertexAi = 41
|
||||||
|
ChannelTypeMistral = 42
|
||||||
|
ChannelTypeDeepSeek = 43
|
||||||
|
ChannelTypeMokaAI = 44
|
||||||
|
ChannelTypeVolcEngine = 45
|
||||||
|
ChannelTypeBaiduV2 = 46
|
||||||
|
ChannelTypeXinference = 47
|
||||||
|
ChannelTypeXai = 48
|
||||||
|
ChannelTypeCoze = 49
|
||||||
|
ChannelTypeKling = 50
|
||||||
|
ChannelTypeJimeng = 51
|
||||||
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
var ChannelBaseURLs = []string{
|
||||||
|
"", // 0
|
||||||
|
"https://api.openai.com", // 1
|
||||||
|
"https://oa.api2d.net", // 2
|
||||||
|
"", // 3
|
||||||
|
"http://localhost:11434", // 4
|
||||||
|
"https://api.openai-sb.com", // 5
|
||||||
|
"https://api.openaimax.com", // 6
|
||||||
|
"https://api.ohmygpt.com", // 7
|
||||||
|
"", // 8
|
||||||
|
"https://api.caipacity.com", // 9
|
||||||
|
"https://api.aiproxy.io", // 10
|
||||||
|
"", // 11
|
||||||
|
"https://api.api2gpt.com", // 12
|
||||||
|
"https://api.aigc2d.com", // 13
|
||||||
|
"https://api.anthropic.com", // 14
|
||||||
|
"https://aip.baidubce.com", // 15
|
||||||
|
"https://open.bigmodel.cn", // 16
|
||||||
|
"https://dashscope.aliyuncs.com", // 17
|
||||||
|
"", // 18
|
||||||
|
"https://api.360.cn", // 19
|
||||||
|
"https://openrouter.ai/api", // 20
|
||||||
|
"https://api.aiproxy.io", // 21
|
||||||
|
"https://fastgpt.run/api/openapi", // 22
|
||||||
|
"https://hunyuan.tencentcloudapi.com", //23
|
||||||
|
"https://generativelanguage.googleapis.com", //24
|
||||||
|
"https://api.moonshot.cn", //25
|
||||||
|
"https://open.bigmodel.cn", //26
|
||||||
|
"https://api.perplexity.ai", //27
|
||||||
|
"", //28
|
||||||
|
"", //29
|
||||||
|
"", //30
|
||||||
|
"https://api.lingyiwanwu.com", //31
|
||||||
|
"", //32
|
||||||
|
"", //33
|
||||||
|
"https://api.cohere.ai", //34
|
||||||
|
"https://api.minimax.chat", //35
|
||||||
|
"", //36
|
||||||
|
"https://api.dify.ai", //37
|
||||||
|
"https://api.jina.ai", //38
|
||||||
|
"https://api.cloudflare.com", //39
|
||||||
|
"https://api.siliconflow.cn", //40
|
||||||
|
"", //41
|
||||||
|
"https://api.mistral.ai", //42
|
||||||
|
"https://api.deepseek.com", //43
|
||||||
|
"https://api.moka.ai", //44
|
||||||
|
"https://ark.cn-beijing.volces.com", //45
|
||||||
|
"https://qianfan.baidubce.com", //46
|
||||||
|
"", //47
|
||||||
|
"https://api.x.ai", //48
|
||||||
|
"https://api.coze.cn", //49
|
||||||
|
"https://api.klingai.com", //50
|
||||||
|
"https://visual.volcengineapi.com", //51
|
||||||
|
}
|
||||||
@@ -1,10 +1,35 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
|
type ContextKey string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ContextKeyRequestStartTime = "request_start_time"
|
ContextKeyOriginalModel ContextKey = "original_model"
|
||||||
ContextKeyUserSetting = "user_setting"
|
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||||
ContextKeyUserQuota = "user_quota"
|
|
||||||
ContextKeyUserStatus = "user_status"
|
/* token related keys */
|
||||||
ContextKeyUserEmail = "user_email"
|
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
|
||||||
ContextKeyUserGroup = "user_group"
|
ContextKeyTokenKey ContextKey = "token_key"
|
||||||
|
ContextKeyTokenId ContextKey = "token_id"
|
||||||
|
ContextKeyTokenGroup ContextKey = "token_group"
|
||||||
|
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
||||||
|
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||||
|
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||||
|
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||||
|
|
||||||
|
/* channel related keys */
|
||||||
|
ContextKeyBaseUrl ContextKey = "base_url"
|
||||||
|
ContextKeyChannelType ContextKey = "channel_type"
|
||||||
|
ContextKeyChannelId ContextKey = "channel_id"
|
||||||
|
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||||
|
ContextKeyParamOverride ContextKey = "param_override"
|
||||||
|
|
||||||
|
/* user related keys */
|
||||||
|
ContextKeyUserId ContextKey = "id"
|
||||||
|
ContextKeyUserSetting ContextKey = "user_setting"
|
||||||
|
ContextKeyUserQuota ContextKey = "user_quota"
|
||||||
|
ContextKeyUserStatus ContextKey = "user_status"
|
||||||
|
ContextKeyUserEmail ContextKey = "user_email"
|
||||||
|
ContextKeyUserGroup ContextKey = "user_group"
|
||||||
|
ContextKeyUsingGroup ContextKey = "group"
|
||||||
|
ContextKeyUserName ContextKey = "username"
|
||||||
)
|
)
|
||||||
|
|||||||
16
constant/endpoint_type.go
Normal file
16
constant/endpoint_type.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
type EndpointType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
EndpointTypeOpenAI EndpointType = "openai"
|
||||||
|
EndpointTypeOpenAIResponse EndpointType = "openai-response"
|
||||||
|
EndpointTypeAnthropic EndpointType = "anthropic"
|
||||||
|
EndpointTypeGemini EndpointType = "gemini"
|
||||||
|
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||||
|
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||||
|
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||||
|
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||||
|
//EndpointTypeKling EndpointType = "kling"
|
||||||
|
//EndpointTypeJimeng EndpointType = "jimeng"
|
||||||
|
)
|
||||||
@@ -1,9 +1,5 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
import (
|
|
||||||
"one-api/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
var StreamingTimeout int
|
var StreamingTimeout int
|
||||||
var DifyDebug bool
|
var DifyDebug bool
|
||||||
var MaxFileDownloadMB int
|
var MaxFileDownloadMB int
|
||||||
@@ -17,39 +13,3 @@ var NotifyLimitCount int
|
|||||||
var NotificationLimitDurationMinute int
|
var NotificationLimitDurationMinute int
|
||||||
var GenerateDefaultToken bool
|
var GenerateDefaultToken bool
|
||||||
var ErrorLogEnabled bool
|
var ErrorLogEnabled bool
|
||||||
|
|
||||||
//var GeminiModelMap = map[string]string{
|
|
||||||
// "gemini-1.0-pro": "v1",
|
|
||||||
//}
|
|
||||||
|
|
||||||
func InitEnv() {
|
|
||||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
|
||||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
|
||||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
|
||||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
|
||||||
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
|
||||||
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
|
||||||
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
|
||||||
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
|
||||||
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
|
||||||
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
|
||||||
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
|
||||||
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
|
||||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
|
||||||
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
|
||||||
// 是否启用错误日志
|
|
||||||
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
|
||||||
|
|
||||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
|
||||||
//if modelVersionMapStr == "" {
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
|
||||||
// parts := strings.Split(pair, ":")
|
|
||||||
// if len(parts) == 2 {
|
|
||||||
// GeminiModelMap[parts[0]] = parts[1]
|
|
||||||
// } else {
|
|
||||||
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ const (
|
|||||||
MjActionPan = "PAN"
|
MjActionPan = "PAN"
|
||||||
MjActionSwapFace = "SWAP_FACE"
|
MjActionSwapFace = "SWAP_FACE"
|
||||||
MjActionUpload = "UPLOAD"
|
MjActionUpload = "UPLOAD"
|
||||||
|
MjActionVideo = "VIDEO"
|
||||||
|
MjActionEdits = "EDITS"
|
||||||
)
|
)
|
||||||
|
|
||||||
var MidjourneyModel2Action = map[string]string{
|
var MidjourneyModel2Action = map[string]string{
|
||||||
@@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{
|
|||||||
"mj_pan": MjActionPan,
|
"mj_pan": MjActionPan,
|
||||||
"swap_face": MjActionSwapFace,
|
"swap_face": MjActionSwapFace,
|
||||||
"mj_upload": MjActionUpload,
|
"mj_upload": MjActionUpload,
|
||||||
|
"mj_video": MjActionVideo,
|
||||||
|
"mj_edits": MjActionEdits,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,11 +5,16 @@ type TaskPlatform string
|
|||||||
const (
|
const (
|
||||||
TaskPlatformSuno TaskPlatform = "suno"
|
TaskPlatformSuno TaskPlatform = "suno"
|
||||||
TaskPlatformMidjourney = "mj"
|
TaskPlatformMidjourney = "mj"
|
||||||
|
TaskPlatformKling TaskPlatform = "kling"
|
||||||
|
TaskPlatformJimeng TaskPlatform = "jimeng"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SunoActionMusic = "MUSIC"
|
SunoActionMusic = "MUSIC"
|
||||||
SunoActionLyrics = "LYRICS"
|
SunoActionLyrics = "LYRICS"
|
||||||
|
|
||||||
|
TaskActionGenerate = "generate"
|
||||||
|
TaskActionTextGenerate = "textGenerate"
|
||||||
)
|
)
|
||||||
|
|
||||||
var SunoModel2Action = map[string]string{
|
var SunoModel2Action = map[string]string{
|
||||||
|
|||||||
@@ -4,11 +4,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -304,34 +307,70 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
|||||||
return balance, nil
|
return balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||||
|
url := "https://api.moonshot.cn/v1/users/me/balance"
|
||||||
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type MoonshotBalanceData struct {
|
||||||
|
AvailableBalance float64 `json:"available_balance"`
|
||||||
|
VoucherBalance float64 `json:"voucher_balance"`
|
||||||
|
CashBalance float64 `json:"cash_balance"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MoonshotBalanceResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data MoonshotBalanceData `json:"data"`
|
||||||
|
Scode string `json:"scode"`
|
||||||
|
Status bool `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
response := MoonshotBalanceResponse{}
|
||||||
|
err = json.Unmarshal(body, &response)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if !response.Status || response.Code != 0 {
|
||||||
|
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||||
|
}
|
||||||
|
availableBalanceCny := response.Data.AvailableBalance
|
||||||
|
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||||
|
channel.UpdateBalance(availableBalanceUsd)
|
||||||
|
return availableBalanceUsd, nil
|
||||||
|
}
|
||||||
|
|
||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||||
if channel.GetBaseURL() == "" {
|
if channel.GetBaseURL() == "" {
|
||||||
channel.BaseURL = &baseURL
|
channel.BaseURL = &baseURL
|
||||||
}
|
}
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeOpenAI:
|
case constant.ChannelTypeOpenAI:
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
case common.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
case common.ChannelTypeCustom:
|
case constant.ChannelTypeCustom:
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
//case common.ChannelTypeOpenAISB:
|
//case common.ChannelTypeOpenAISB:
|
||||||
// return updateChannelOpenAISBBalance(channel)
|
// return updateChannelOpenAISBBalance(channel)
|
||||||
case common.ChannelTypeAIProxy:
|
case constant.ChannelTypeAIProxy:
|
||||||
return updateChannelAIProxyBalance(channel)
|
return updateChannelAIProxyBalance(channel)
|
||||||
case common.ChannelTypeAPI2GPT:
|
case constant.ChannelTypeAPI2GPT:
|
||||||
return updateChannelAPI2GPTBalance(channel)
|
return updateChannelAPI2GPTBalance(channel)
|
||||||
case common.ChannelTypeAIGC2D:
|
case constant.ChannelTypeAIGC2D:
|
||||||
return updateChannelAIGC2DBalance(channel)
|
return updateChannelAIGC2DBalance(channel)
|
||||||
case common.ChannelTypeSiliconFlow:
|
case constant.ChannelTypeSiliconFlow:
|
||||||
return updateChannelSiliconFlowBalance(channel)
|
return updateChannelSiliconFlowBalance(channel)
|
||||||
case common.ChannelTypeDeepSeek:
|
case constant.ChannelTypeDeepSeek:
|
||||||
return updateChannelDeepSeekBalance(channel)
|
return updateChannelDeepSeekBalance(channel)
|
||||||
case common.ChannelTypeOpenRouter:
|
case constant.ChannelTypeOpenRouter:
|
||||||
return updateChannelOpenRouterBalance(channel)
|
return updateChannelOpenRouterBalance(channel)
|
||||||
|
case constant.ChannelTypeMoonshot:
|
||||||
|
return updateChannelMoonshotBalance(channel)
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -31,15 +31,21 @@ import (
|
|||||||
|
|
||||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
if channel.Type == common.ChannelTypeMidjourney {
|
if channel.Type == constant.ChannelTypeMidjourney {
|
||||||
return errors.New("midjourney channel test is not supported"), nil
|
return errors.New("midjourney channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeMidjourneyPlus {
|
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||||
return errors.New("midjourney plus channel test is not supported!!!"), nil
|
return errors.New("midjourney plus channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeSunoAPI {
|
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||||
return errors.New("suno channel test is not supported"), nil
|
return errors.New("suno channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
|
if channel.Type == constant.ChannelTypeKling {
|
||||||
|
return errors.New("kling channel test is not supported"), nil
|
||||||
|
}
|
||||||
|
if channel.Type == constant.ChannelTypeJimeng {
|
||||||
|
return errors.New("jimeng channel test is not supported"), nil
|
||||||
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
@@ -50,7 +56,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||||
strings.Contains(testModel, "embed") ||
|
strings.Contains(testModel, "embed") ||
|
||||||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
|
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||||
requestPath = "/v1/embeddings" // 修改请求路径
|
requestPath = "/v1/embeddings" // 修改请求路径
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,13 +96,13 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
|
|
||||||
info := relaycommon.GenRelayInfo(c)
|
info := relaycommon.GenRelayInfo(c)
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, info)
|
err = helper.ModelMappedHelper(c, info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
testModel = info.UpstreamModelName
|
testModel = info.UpstreamModelName
|
||||||
|
|
||||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||||
@@ -165,10 +171,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
|
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio)
|
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
||||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other)
|
||||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -196,7 +202,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|||||||
testRequest.MaxTokens = 50
|
testRequest.MaxTokens = 50
|
||||||
}
|
}
|
||||||
} else if strings.Contains(model, "gemini") {
|
} else if strings.Contains(model, "gemini") {
|
||||||
testRequest.MaxTokens = 300
|
testRequest.MaxTokens = 3000
|
||||||
} else {
|
} else {
|
||||||
testRequest.MaxTokens = 10
|
testRequest.MaxTokens = 10
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -40,6 +41,17 @@ type OpenAIModelsResponse struct {
|
|||||||
Success bool `json:"success"`
|
Success bool `json:"success"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseStatusFilter(statusParam string) int {
|
||||||
|
switch strings.ToLower(statusParam) {
|
||||||
|
case "enabled", "1":
|
||||||
|
return common.ChannelStatusEnabled
|
||||||
|
case "disabled", "0":
|
||||||
|
return 0
|
||||||
|
default:
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func GetAllChannels(c *gin.Context) {
|
func GetAllChannels(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
p, _ := strconv.Atoi(c.Query("p"))
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||||
@@ -52,34 +64,89 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
channelData := make([]*model.Channel, 0)
|
channelData := make([]*model.Channel, 0)
|
||||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||||
|
statusParam := c.Query("status")
|
||||||
|
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
||||||
|
statusFilter := parseStatusFilter(statusParam)
|
||||||
|
// type filter
|
||||||
|
typeStr := c.Query("type")
|
||||||
|
typeFilter := -1
|
||||||
|
if typeStr != "" {
|
||||||
|
if t, err := strconv.Atoi(typeStr); err == nil {
|
||||||
|
typeFilter = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
if enableTagMode {
|
if enableTagMode {
|
||||||
// tag 分页:先分页 tag,再取各 tag 下 channels
|
|
||||||
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
|
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
if tag != nil && *tag != "" {
|
if tag == nil || *tag == "" {
|
||||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
|
continue
|
||||||
if err == nil {
|
|
||||||
channelData = append(channelData, tagChannel...)
|
|
||||||
}
|
}
|
||||||
|
tagChannels, err := model.GetChannelsByTag(*tag, idSort)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
filtered := make([]*model.Channel, 0)
|
||||||
|
for _, ch := range tagChannels {
|
||||||
|
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if typeFilter >= 0 && ch.Type != typeFilter {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, ch)
|
||||||
|
}
|
||||||
|
channelData = append(channelData, filtered...)
|
||||||
}
|
}
|
||||||
// 计算 tag 总数用于分页
|
|
||||||
total, _ = model.CountAllTags()
|
total, _ = model.CountAllTags()
|
||||||
} else {
|
} else {
|
||||||
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
|
baseQuery := model.DB.Model(&model.Channel{})
|
||||||
|
if typeFilter >= 0 {
|
||||||
|
baseQuery = baseQuery.Where("type = ?", typeFilter)
|
||||||
|
}
|
||||||
|
if statusFilter == common.ChannelStatusEnabled {
|
||||||
|
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||||
|
} else if statusFilter == 0 {
|
||||||
|
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
baseQuery.Count(&total)
|
||||||
|
|
||||||
|
order := "priority desc"
|
||||||
|
if idSort {
|
||||||
|
order = "id desc"
|
||||||
|
}
|
||||||
|
|
||||||
|
err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channelData = channels
|
}
|
||||||
total, _ = model.CountAllChannels()
|
|
||||||
|
countQuery := model.DB.Model(&model.Channel{})
|
||||||
|
if statusFilter == common.ChannelStatusEnabled {
|
||||||
|
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||||
|
} else if statusFilter == 0 {
|
||||||
|
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||||
|
}
|
||||||
|
var results []struct {
|
||||||
|
Type int64
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||||
|
typeCounts := make(map[int64]int64)
|
||||||
|
for _, r := range results {
|
||||||
|
typeCounts[r.Type] = r.Count
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -90,6 +157,7 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
"total": total,
|
"total": total,
|
||||||
"page": p,
|
"page": p,
|
||||||
"page_size": pageSize,
|
"page_size": pageSize,
|
||||||
|
"type_counts": typeCounts,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@@ -114,22 +182,15 @@ func FetchUpstreamModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//if channel.Type != common.ChannelTypeOpenAI {
|
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||||
// c.JSON(http.StatusOK, gin.H{
|
|
||||||
// "success": false,
|
|
||||||
// "message": "仅支持 OpenAI 类型渠道",
|
|
||||||
// })
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeGemini:
|
case constant.ChannelTypeGemini:
|
||||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||||
case common.ChannelTypeAli:
|
case constant.ChannelTypeAli:
|
||||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||||
}
|
}
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
@@ -153,7 +214,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
|||||||
var ids []string
|
var ids []string
|
||||||
for _, model := range result.Data {
|
for _, model := range result.Data {
|
||||||
id := model.ID
|
id := model.ID
|
||||||
if channel.Type == common.ChannelTypeGemini {
|
if channel.Type == constant.ChannelTypeGemini {
|
||||||
id = strings.TrimPrefix(id, "models/")
|
id = strings.TrimPrefix(id, "models/")
|
||||||
}
|
}
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
@@ -186,6 +247,8 @@ func SearchChannels(c *gin.Context) {
|
|||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
group := c.Query("group")
|
group := c.Query("group")
|
||||||
modelKeyword := c.Query("model")
|
modelKeyword := c.Query("model")
|
||||||
|
statusParam := c.Query("status")
|
||||||
|
statusFilter := parseStatusFilter(statusParam)
|
||||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||||
channelData := make([]*model.Channel, 0)
|
channelData := make([]*model.Channel, 0)
|
||||||
@@ -217,10 +280,74 @@ func SearchChannels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
channelData = channels
|
channelData = channels
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
|
||||||
|
filtered := make([]*model.Channel, 0, len(channelData))
|
||||||
|
for _, ch := range channelData {
|
||||||
|
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, ch)
|
||||||
|
}
|
||||||
|
channelData = filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate type counts for search results
|
||||||
|
typeCounts := make(map[int64]int64)
|
||||||
|
for _, channel := range channelData {
|
||||||
|
typeCounts[int64(channel.Type)]++
|
||||||
|
}
|
||||||
|
|
||||||
|
typeParam := c.Query("type")
|
||||||
|
typeFilter := -1
|
||||||
|
if typeParam != "" {
|
||||||
|
if tp, err := strconv.Atoi(typeParam); err == nil {
|
||||||
|
typeFilter = tp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if typeFilter >= 0 {
|
||||||
|
filtered := make([]*model.Channel, 0, len(channelData))
|
||||||
|
for _, ch := range channelData {
|
||||||
|
if ch.Type == typeFilter {
|
||||||
|
filtered = append(filtered, ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
channelData = filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize <= 0 {
|
||||||
|
pageSize = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
total := len(channelData)
|
||||||
|
startIdx := (page - 1) * pageSize
|
||||||
|
if startIdx > total {
|
||||||
|
startIdx = total
|
||||||
|
}
|
||||||
|
endIdx := startIdx + pageSize
|
||||||
|
if endIdx > total {
|
||||||
|
endIdx = total
|
||||||
|
}
|
||||||
|
|
||||||
|
pagedData := channelData[startIdx:endIdx]
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": channelData,
|
"data": gin.H{
|
||||||
|
"items": pagedData,
|
||||||
|
"total": total,
|
||||||
|
"type_counts": typeCounts,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -283,7 +410,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if addChannelRequest.Channel.Type == common.ChannelTypeVertexAi {
|
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||||
if addChannelRequest.Channel.Other == "" {
|
if addChannelRequest.Channel.Other == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -566,7 +693,7 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeVertexAi {
|
if channel.Type == constant.ChannelTypeVertexAi {
|
||||||
if channel.Other == "" {
|
if channel.Other == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -595,6 +722,7 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
channel.Key = ""
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@@ -620,7 +748,7 @@ func FetchModels(c *gin.Context) {
|
|||||||
|
|
||||||
baseURL := req.BaseURL
|
baseURL := req.BaseURL
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = common.ChannelBaseURLs[req.Type]
|
baseURL = constant.ChannelBaseURLs[req.Type]
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetGroups(c *gin.Context) {
|
func GetGroups(c *gin.Context) {
|
||||||
groupNames := make([]string, 0)
|
groupNames := make([]string, 0)
|
||||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
for groupName := range ratio_setting.GetGroupRatioCopy() {
|
||||||
groupNames = append(groupNames, groupName)
|
groupNames = append(groupNames, groupName)
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
|
|||||||
userGroup := ""
|
userGroup := ""
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
userGroup, _ = model.GetUserGroup(userId, false)
|
userGroup, _ = model.GetUserGroup(userId, false)
|
||||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
|
||||||
// UserUsableGroups contains the groups that the user can use
|
// UserUsableGroups contains the groups that the user can use
|
||||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||||
if desc, ok := userUsableGroups[groupName]; ok {
|
if desc, ok := userUsableGroups[groupName]; ok {
|
||||||
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if setting.GroupInUserUsableGroups("auto") {
|
||||||
|
usableGroups["auto"] = map[string]interface{}{
|
||||||
|
"ratio": "自动",
|
||||||
|
"desc": setting.GetUsableGroupDescription("auto"),
|
||||||
|
}
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ import (
|
|||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/console_setting"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/operation_setting"
|
||||||
"one-api/setting/system_setting"
|
"one-api/setting/system_setting"
|
||||||
"one-api/setting/console_setting"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -75,6 +75,8 @@ func GetStatus(c *gin.Context) {
|
|||||||
"chats": setting.Chats,
|
"chats": setting.Chats,
|
||||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||||
|
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||||
|
"pay_methods": setting.PayMethods,
|
||||||
|
|
||||||
// 面板启用开关
|
// 面板启用开关
|
||||||
"api_info_enabled": cs.ApiInfoEnabled,
|
"api_info_enabled": cs.ApiInfoEnabled,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/samber/lo"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
@@ -14,7 +15,7 @@ import (
|
|||||||
"one-api/relay/channel/minimax"
|
"one-api/relay/channel/minimax"
|
||||||
"one-api/relay/channel/moonshot"
|
"one-api/relay/channel/moonshot"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
"one-api/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@@ -23,30 +24,10 @@ var openAIModels []dto.OpenAIModels
|
|||||||
var openAIModelsMap map[string]dto.OpenAIModels
|
var openAIModelsMap map[string]dto.OpenAIModels
|
||||||
var channelId2Models map[int][]string
|
var channelId2Models map[int][]string
|
||||||
|
|
||||||
func getPermission() []dto.OpenAIModelPermission {
|
|
||||||
var permission []dto.OpenAIModelPermission
|
|
||||||
permission = append(permission, dto.OpenAIModelPermission{
|
|
||||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
|
||||||
Object: "model_permission",
|
|
||||||
Created: 1626777600,
|
|
||||||
AllowCreateEngine: true,
|
|
||||||
AllowSampling: true,
|
|
||||||
AllowLogprobs: true,
|
|
||||||
AllowSearchIndices: false,
|
|
||||||
AllowView: true,
|
|
||||||
AllowFineTuning: false,
|
|
||||||
Organization: "*",
|
|
||||||
Group: nil,
|
|
||||||
IsBlocking: false,
|
|
||||||
})
|
|
||||||
return permission
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
permission := getPermission()
|
for i := 0; i < constant.APITypeDummy; i++ {
|
||||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
if i == constant.APITypeAIProxyLibrary {
|
||||||
if i == relayconstant.APITypeAIProxyLibrary {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
adaptor := relay.GetAdaptor(i)
|
adaptor := relay.GetAdaptor(i)
|
||||||
@@ -58,9 +39,6 @@ func init() {
|
|||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: channelName,
|
OwnedBy: channelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -70,9 +48,6 @@ func init() {
|
|||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: ai360.ChannelName,
|
OwnedBy: ai360.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range moonshot.ModelList {
|
for _, modelName := range moonshot.ModelList {
|
||||||
@@ -81,9 +56,6 @@ func init() {
|
|||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: moonshot.ChannelName,
|
OwnedBy: moonshot.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range lingyiwanwu.ModelList {
|
for _, modelName := range lingyiwanwu.ModelList {
|
||||||
@@ -92,9 +64,6 @@ func init() {
|
|||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: lingyiwanwu.ChannelName,
|
OwnedBy: lingyiwanwu.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range minimax.ModelList {
|
for _, modelName := range minimax.ModelList {
|
||||||
@@ -103,9 +72,6 @@ func init() {
|
|||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: minimax.ChannelName,
|
OwnedBy: minimax.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||||
@@ -114,9 +80,6 @@ func init() {
|
|||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: "midjourney",
|
OwnedBy: "midjourney",
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
||||||
@@ -124,9 +87,9 @@ func init() {
|
|||||||
openAIModelsMap[aiModel.Id] = aiModel
|
openAIModelsMap[aiModel.Id] = aiModel
|
||||||
}
|
}
|
||||||
channelId2Models = make(map[int][]string)
|
channelId2Models = make(map[int][]string)
|
||||||
for i := 1; i <= common.ChannelTypeDummy; i++ {
|
for i := 1; i <= constant.ChannelTypeDummy; i++ {
|
||||||
apiType, success := relayconstant.ChannelType2APIType(i)
|
apiType, success := common.ChannelType2APIType(i)
|
||||||
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
|
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
meta := &relaycommon.RelayInfo{ChannelType: i}
|
||||||
@@ -134,15 +97,17 @@ func init() {
|
|||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
channelId2Models[i] = adaptor.GetModelList()
|
channelId2Models[i] = adaptor.GetModelList()
|
||||||
}
|
}
|
||||||
|
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||||||
|
return m.Id
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListModels(c *gin.Context) {
|
func ListModels(c *gin.Context) {
|
||||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||||
permission := getPermission()
|
|
||||||
|
|
||||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
if modelLimitEnable {
|
if modelLimitEnable {
|
||||||
s, ok := c.Get("token_model_limit")
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||||
var tokenModelLimit map[string]bool
|
var tokenModelLimit map[string]bool
|
||||||
if ok {
|
if ok {
|
||||||
tokenModelLimit = s.(map[string]bool)
|
tokenModelLimit = s.(map[string]bool)
|
||||||
@@ -150,23 +115,22 @@ func ListModels(c *gin.Context) {
|
|||||||
tokenModelLimit = map[string]bool{}
|
tokenModelLimit = map[string]bool{}
|
||||||
}
|
}
|
||||||
for allowModel, _ := range tokenModelLimit {
|
for allowModel, _ := range tokenModelLimit {
|
||||||
if _, ok := openAIModelsMap[allowModel]; ok {
|
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
|
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||||
|
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||||
} else {
|
} else {
|
||||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||||
Id: allowModel,
|
Id: allowModel,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: "custom",
|
OwnedBy: "custom",
|
||||||
Permission: permission,
|
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||||
Root: allowModel,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
userGroup, err := model.GetUserGroup(userId, true)
|
userGroup, err := model.GetUserGroup(userId, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -175,23 +139,34 @@ func ListModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
group := userGroup
|
group := userGroup
|
||||||
tokenGroup := c.GetString("token_group")
|
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||||
if tokenGroup != "" {
|
if tokenGroup != "" {
|
||||||
group = tokenGroup
|
group = tokenGroup
|
||||||
}
|
}
|
||||||
models := model.GetGroupModels(group)
|
var models []string
|
||||||
for _, s := range models {
|
if tokenGroup == "auto" {
|
||||||
if _, ok := openAIModelsMap[s]; ok {
|
for _, autoGroup := range setting.AutoGroups {
|
||||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||||
|
for _, g := range groupModels {
|
||||||
|
if !common.StringsContains(models, g) {
|
||||||
|
models = append(models, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
models = model.GetGroupEnabledModels(group)
|
||||||
|
}
|
||||||
|
for _, modelName := range models {
|
||||||
|
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||||
|
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||||
|
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||||
} else {
|
} else {
|
||||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||||
Id: s,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: "custom",
|
OwnedBy: "custom",
|
||||||
Permission: permission,
|
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||||
Root: s,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/console_setting"
|
"one-api/setting/console_setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"one-api/setting/system_setting"
|
"one-api/setting/system_setting"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -103,7 +104,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
err = setting.CheckGroupRatio(option.Value)
|
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
@@ -13,6 +12,8 @@ import (
|
|||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Playground(c *gin.Context) {
|
func Playground(c *gin.Context) {
|
||||||
@@ -57,13 +58,22 @@ func Playground(c *gin.Context) {
|
|||||||
c.Set("group", group)
|
c.Set("group", group)
|
||||||
}
|
}
|
||||||
c.Set("token_name", "playground-"+group)
|
c.Set("token_name", "playground-"+group)
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
|
|
||||||
|
// Write user context to ensure acceptUnsetRatio is available
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
userCache, err := model.GetUserCache(userId)
|
||||||
|
if err != nil {
|
||||||
|
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userCache.WriteContext(c)
|
||||||
Relay(c)
|
Relay(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -13,7 +13,7 @@ func GetPricing(c *gin.Context) {
|
|||||||
userId, exists := c.Get("id")
|
userId, exists := c.Get("id")
|
||||||
usableGroup := map[string]string{}
|
usableGroup := map[string]string{}
|
||||||
groupRatio := map[string]float64{}
|
groupRatio := map[string]float64{}
|
||||||
for s, f := range setting.GetGroupRatioCopy() {
|
for s, f := range ratio_setting.GetGroupRatioCopy() {
|
||||||
groupRatio[s] = f
|
groupRatio[s] = f
|
||||||
}
|
}
|
||||||
var group string
|
var group string
|
||||||
@@ -22,7 +22,7 @@ func GetPricing(c *gin.Context) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
group = user.Group
|
group = user.Group
|
||||||
for g := range groupRatio {
|
for g := range groupRatio {
|
||||||
ratio, ok := setting.GetGroupGroupRatio(group, g)
|
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
|
||||||
if ok {
|
if ok {
|
||||||
groupRatio[g] = ratio
|
groupRatio[g] = ratio
|
||||||
}
|
}
|
||||||
@@ -32,7 +32,7 @@ func GetPricing(c *gin.Context) {
|
|||||||
|
|
||||||
usableGroup = setting.GetUserUsableGroups(group)
|
usableGroup = setting.GetUserUsableGroups(group)
|
||||||
// check groupRatio contains usableGroup
|
// check groupRatio contains usableGroup
|
||||||
for group := range setting.GetGroupRatioCopy() {
|
for group := range ratio_setting.GetGroupRatioCopy() {
|
||||||
if _, ok := usableGroup[group]; !ok {
|
if _, ok := usableGroup[group]; !ok {
|
||||||
delete(groupRatio, group)
|
delete(groupRatio, group)
|
||||||
}
|
}
|
||||||
@@ -47,7 +47,7 @@ func GetPricing(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ResetModelRatio(c *gin.Context) {
|
func ResetModelRatio(c *gin.Context) {
|
||||||
defaultStr := operation_setting.DefaultModelRatio2JSONString()
|
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
|
||||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
@@ -56,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
|
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
|
|||||||
24
controller/ratio_config.go
Normal file
24
controller/ratio_config.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetRatioConfig(c *gin.Context) {
|
||||||
|
if !ratio_setting.IsExposeRatioEnabled() {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "倍率配置接口未启用",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": ratio_setting.GetExposedData(),
|
||||||
|
})
|
||||||
|
}
|
||||||
474
controller/ratio_sync.go
Normal file
474
controller/ratio_sync.go
Normal file
@@ -0,0 +1,474 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultTimeoutSeconds = 10
|
||||||
|
defaultEndpoint = "/api/ratio_config"
|
||||||
|
maxConcurrentFetches = 8
|
||||||
|
)
|
||||||
|
|
||||||
|
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||||
|
|
||||||
|
type upstreamResult struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Data map[string]any `json:"data,omitempty"`
|
||||||
|
Err string `json:"err,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchUpstreamRatios(c *gin.Context) {
|
||||||
|
var req dto.UpstreamRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Timeout <= 0 {
|
||||||
|
req.Timeout = defaultTimeoutSeconds
|
||||||
|
}
|
||||||
|
|
||||||
|
var upstreams []dto.UpstreamDTO
|
||||||
|
|
||||||
|
if len(req.Upstreams) > 0 {
|
||||||
|
for _, u := range req.Upstreams {
|
||||||
|
if strings.HasPrefix(u.BaseURL, "http") {
|
||||||
|
if u.Endpoint == "" {
|
||||||
|
u.Endpoint = defaultEndpoint
|
||||||
|
}
|
||||||
|
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||||
|
upstreams = append(upstreams, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if len(req.ChannelIDs) > 0 {
|
||||||
|
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||||
|
for _, id64 := range req.ChannelIDs {
|
||||||
|
intIds = append(intIds, int(id64))
|
||||||
|
}
|
||||||
|
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, ch := range dbChannels {
|
||||||
|
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||||
|
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||||
|
ID: ch.Id,
|
||||||
|
Name: ch.Name,
|
||||||
|
BaseURL: strings.TrimRight(base, "/"),
|
||||||
|
Endpoint: "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(upstreams) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
ch := make(chan upstreamResult, len(upstreams))
|
||||||
|
|
||||||
|
sem := make(chan struct{}, maxConcurrentFetches)
|
||||||
|
|
||||||
|
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||||
|
|
||||||
|
for _, chn := range upstreams {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(chItem dto.UpstreamDTO) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
sem <- struct{}{}
|
||||||
|
defer func() { <-sem }()
|
||||||
|
|
||||||
|
endpoint := chItem.Endpoint
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = defaultEndpoint
|
||||||
|
} else if !strings.HasPrefix(endpoint, "/") {
|
||||||
|
endpoint = "/" + endpoint
|
||||||
|
}
|
||||||
|
fullURL := chItem.BaseURL + endpoint
|
||||||
|
|
||||||
|
uniqueName := chItem.Name
|
||||||
|
if chItem.ID != 0 {
|
||||||
|
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 兼容两种上游接口格式:
|
||||||
|
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||||
|
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||||
|
var body struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||||
|
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !body.Success {
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试按 type1 解析
|
||||||
|
var type1Data map[string]any
|
||||||
|
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||||
|
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||||
|
isType1 := false
|
||||||
|
for _, rt := range ratioTypes {
|
||||||
|
if _, ok := type1Data[rt]; ok {
|
||||||
|
isType1 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isType1 {
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||||
|
var pricingItems []struct {
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
QuotaType int `json:"quota_type"`
|
||||||
|
ModelRatio float64 `json:"model_ratio"`
|
||||||
|
ModelPrice float64 `json:"model_price"`
|
||||||
|
CompletionRatio float64 `json:"completion_ratio"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||||
|
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelRatioMap := make(map[string]float64)
|
||||||
|
completionRatioMap := make(map[string]float64)
|
||||||
|
modelPriceMap := make(map[string]float64)
|
||||||
|
|
||||||
|
for _, item := range pricingItems {
|
||||||
|
if item.QuotaType == 1 {
|
||||||
|
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||||
|
} else {
|
||||||
|
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||||
|
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||||
|
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
converted := make(map[string]any)
|
||||||
|
|
||||||
|
if len(modelRatioMap) > 0 {
|
||||||
|
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||||
|
for k, v := range modelRatioMap {
|
||||||
|
ratioAny[k] = v
|
||||||
|
}
|
||||||
|
converted["model_ratio"] = ratioAny
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(completionRatioMap) > 0 {
|
||||||
|
compAny := make(map[string]any, len(completionRatioMap))
|
||||||
|
for k, v := range completionRatioMap {
|
||||||
|
compAny[k] = v
|
||||||
|
}
|
||||||
|
converted["completion_ratio"] = compAny
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(modelPriceMap) > 0 {
|
||||||
|
priceAny := make(map[string]any, len(modelPriceMap))
|
||||||
|
for k, v := range modelPriceMap {
|
||||||
|
priceAny[k] = v
|
||||||
|
}
|
||||||
|
converted["model_price"] = priceAny
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||||
|
}(chn)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(ch)
|
||||||
|
|
||||||
|
localData := ratio_setting.GetExposedData()
|
||||||
|
|
||||||
|
var testResults []dto.TestResult
|
||||||
|
var successfulChannels []struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
for r := range ch {
|
||||||
|
if r.Err != "" {
|
||||||
|
testResults = append(testResults, dto.TestResult{
|
||||||
|
Name: r.Name,
|
||||||
|
Status: "error",
|
||||||
|
Error: r.Err,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
testResults = append(testResults, dto.TestResult{
|
||||||
|
Name: r.Name,
|
||||||
|
Status: "success",
|
||||||
|
})
|
||||||
|
successfulChannels = append(successfulChannels, struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}{name: r.Name, data: r.Data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
differences := buildDifferences(localData, successfulChannels)
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"differences": differences,
|
||||||
|
"test_results": testResults,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}) map[string]map[string]dto.DifferenceItem {
|
||||||
|
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||||
|
|
||||||
|
allModels := make(map[string]struct{})
|
||||||
|
|
||||||
|
for _, ratioType := range ratioTypes {
|
||||||
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
|
for modelName := range localRatio {
|
||||||
|
allModels[modelName] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, channel := range successfulChannels {
|
||||||
|
for _, ratioType := range ratioTypes {
|
||||||
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
|
for modelName := range upstreamRatio {
|
||||||
|
allModels[modelName] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
confidenceMap := make(map[string]map[string]bool)
|
||||||
|
|
||||||
|
// 预处理阶段:检查pricing接口的可信度
|
||||||
|
for _, channel := range successfulChannels {
|
||||||
|
confidenceMap[channel.name] = make(map[string]bool)
|
||||||
|
|
||||||
|
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||||
|
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||||
|
|
||||||
|
if hasModelRatio && hasCompletionRatio {
|
||||||
|
// 遍历所有模型,检查是否满足不可信条件
|
||||||
|
for modelName := range allModels {
|
||||||
|
// 默认为可信
|
||||||
|
confidenceMap[channel.name][modelName] = true
|
||||||
|
|
||||||
|
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||||
|
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||||
|
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||||
|
// 转换为float64进行比较
|
||||||
|
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||||
|
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||||
|
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||||
|
confidenceMap[channel.name][modelName] = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||||
|
for modelName := range allModels {
|
||||||
|
confidenceMap[channel.name][modelName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelName := range allModels {
|
||||||
|
for _, ratioType := range ratioTypes {
|
||||||
|
var localValue interface{} = nil
|
||||||
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
|
if val, exists := localRatio[modelName]; exists {
|
||||||
|
localValue = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamValues := make(map[string]interface{})
|
||||||
|
confidenceValues := make(map[string]bool)
|
||||||
|
hasUpstreamValue := false
|
||||||
|
hasDifference := false
|
||||||
|
|
||||||
|
for _, channel := range successfulChannels {
|
||||||
|
var upstreamValue interface{} = nil
|
||||||
|
|
||||||
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
|
if val, exists := upstreamRatio[modelName]; exists {
|
||||||
|
upstreamValue = val
|
||||||
|
hasUpstreamValue = true
|
||||||
|
|
||||||
|
if localValue != nil && localValue != val {
|
||||||
|
hasDifference = true
|
||||||
|
} else if localValue == val {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if upstreamValue == nil && localValue == nil {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
|
||||||
|
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||||
|
hasDifference = true
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamValues[channel.name] = upstreamValue
|
||||||
|
|
||||||
|
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldInclude := false
|
||||||
|
|
||||||
|
if localValue != nil {
|
||||||
|
if hasDifference {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if hasUpstreamValue {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldInclude {
|
||||||
|
if differences[modelName] == nil {
|
||||||
|
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||||
|
}
|
||||||
|
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||||
|
Current: localValue,
|
||||||
|
Upstreams: upstreamValues,
|
||||||
|
Confidence: confidenceValues,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
channelHasDiff := make(map[string]bool)
|
||||||
|
for _, ratioMap := range differences {
|
||||||
|
for _, item := range ratioMap {
|
||||||
|
for chName, val := range item.Upstreams {
|
||||||
|
if val != nil && val != "same" {
|
||||||
|
channelHasDiff[chName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelName, ratioMap := range differences {
|
||||||
|
for ratioType, item := range ratioMap {
|
||||||
|
for chName := range item.Upstreams {
|
||||||
|
if !channelHasDiff[chName] {
|
||||||
|
delete(item.Upstreams, chName)
|
||||||
|
delete(item.Confidence, chName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allSame := true
|
||||||
|
for _, v := range item.Upstreams {
|
||||||
|
if v != "same" {
|
||||||
|
allSame = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(item.Upstreams) == 0 || allSame {
|
||||||
|
delete(ratioMap, ratioType)
|
||||||
|
} else {
|
||||||
|
differences[modelName][ratioType] = item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ratioMap) == 0 {
|
||||||
|
delete(differences, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return differences
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSyncableChannels(c *gin.Context) {
|
||||||
|
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var syncableChannels []dto.SyncableChannel
|
||||||
|
for _, channel := range channels {
|
||||||
|
if channel.GetBaseURL() != "" {
|
||||||
|
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||||
|
ID: channel.Id,
|
||||||
|
Name: channel.Name,
|
||||||
|
BaseURL: channel.GetBaseURL(),
|
||||||
|
Status: channel.Status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": syncableChannels,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -8,12 +8,12 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
constant2 "one-api/constant"
|
constant2 "one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/constant"
|
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -69,7 +69,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
@@ -132,7 +132,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||||
@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
|||||||
AutoBan: &autoBanInt,
|
AutoBan: &autoBanInt,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
|
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||||
}
|
}
|
||||||
@@ -295,7 +295,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
|||||||
}
|
}
|
||||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||||
channelType := c.GetInt("channel_type")
|
channelType := c.GetInt("channel_type")
|
||||||
if channelType == common.ChannelTypeAnthropic {
|
if channelType == constant.ChannelTypeAnthropic {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||||
break
|
break
|
||||||
@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||||
var err *dto.TaskError
|
var err *dto.TaskError
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||||
err = relay.RelayTaskFetch(c, relayMode)
|
err = relay.RelayTaskFetch(c, relayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayTaskSubmit(c, relayMode)
|
err = relay.RelayTaskSubmit(c, relayMode)
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
|||||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||||
case constant.TaskPlatformSuno:
|
case constant.TaskPlatformSuno:
|
||||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
|
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
|
||||||
|
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
|
||||||
default:
|
default:
|
||||||
common.SysLog("未知平台")
|
common.SysLog("未知平台")
|
||||||
}
|
}
|
||||||
|
|||||||
138
controller/task_video.go
Normal file
138
controller/task_video.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/relay"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||||
|
for channelId, taskIds := range taskChannelM {
|
||||||
|
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||||
|
if len(taskIds) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||||
|
if err != nil {
|
||||||
|
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||||
|
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||||
|
"status": "FAILURE",
|
||||||
|
"progress": "100%",
|
||||||
|
})
|
||||||
|
if errUpdate != nil {
|
||||||
|
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||||
|
}
|
||||||
|
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||||
|
}
|
||||||
|
adaptor := relay.GetTaskAdaptor(platform)
|
||||||
|
if adaptor == nil {
|
||||||
|
return fmt.Errorf("video adaptor not found")
|
||||||
|
}
|
||||||
|
for _, taskId := range taskIds {
|
||||||
|
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||||
|
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||||
|
if channel.GetBaseURL() != "" {
|
||||||
|
baseURL = channel.GetBaseURL()
|
||||||
|
}
|
||||||
|
|
||||||
|
task := taskM[taskId]
|
||||||
|
if task == nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||||
|
return fmt.Errorf("task %s not found", taskId)
|
||||||
|
}
|
||||||
|
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||||
|
"task_id": taskId,
|
||||||
|
"action": task.Action,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
//if resp.StatusCode != http.StatusOK {
|
||||||
|
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||||
|
//}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskResult, err := adaptor.ParseTaskResult(responseBody)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
//if taskResult.Code != 0 {
|
||||||
|
// return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||||
|
//}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if taskResult.Status == "" {
|
||||||
|
return fmt.Errorf("task %s status is empty", taskId)
|
||||||
|
}
|
||||||
|
task.Status = model.TaskStatus(taskResult.Status)
|
||||||
|
switch taskResult.Status {
|
||||||
|
case model.TaskStatusSubmitted:
|
||||||
|
task.Progress = "10%"
|
||||||
|
case model.TaskStatusQueued:
|
||||||
|
task.Progress = "20%"
|
||||||
|
case model.TaskStatusInProgress:
|
||||||
|
task.Progress = "30%"
|
||||||
|
if task.StartTime == 0 {
|
||||||
|
task.StartTime = now
|
||||||
|
}
|
||||||
|
case model.TaskStatusSuccess:
|
||||||
|
task.Progress = "100%"
|
||||||
|
if task.FinishTime == 0 {
|
||||||
|
task.FinishTime = now
|
||||||
|
}
|
||||||
|
task.FailReason = taskResult.Url
|
||||||
|
case model.TaskStatusFailure:
|
||||||
|
task.Status = model.TaskStatusFailure
|
||||||
|
task.Progress = "100%"
|
||||||
|
if task.FinishTime == 0 {
|
||||||
|
task.FinishTime = now
|
||||||
|
}
|
||||||
|
task.FailReason = taskResult.Reason
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||||
|
quota := task.Quota
|
||||||
|
if quota != 0 {
|
||||||
|
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||||
|
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||||
|
}
|
||||||
|
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||||
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||||
|
}
|
||||||
|
if taskResult.Progress != "" {
|
||||||
|
task.Progress = taskResult.Progress
|
||||||
|
}
|
||||||
|
|
||||||
|
task.Data = responseBody
|
||||||
|
if err := task.Update(); err != nil {
|
||||||
|
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -258,3 +258,32 @@ func UpdateToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TokenBatch struct {
|
||||||
|
Ids []int `json:"ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteTokenBatch(c *gin.Context) {
|
||||||
|
tokenBatch := TokenBatch{}
|
||||||
|
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": count,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) {
|
|||||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payType := "wxpay"
|
|
||||||
if req.PaymentMethod == "zfb" {
|
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||||
payType = "alipay"
|
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||||
}
|
return
|
||||||
if req.PaymentMethod == "wx" {
|
|
||||||
req.PaymentMethod = "wxpay"
|
|
||||||
payType = "wxpay"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
callBackAddress := service.GetCallbackAddress()
|
callBackAddress := service.GetCallbackAddress()
|
||||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||||
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||||
Type: payType,
|
Type: req.PaymentMethod,
|
||||||
ServiceTradeNo: tradeNo,
|
ServiceTradeNo: tradeNo,
|
||||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||||
|
|||||||
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
|
|||||||
UnlimitedQuota: true,
|
UnlimitedQuota: true,
|
||||||
ModelLimitsEnabled: false,
|
ModelLimitsEnabled: false,
|
||||||
}
|
}
|
||||||
|
if setting.DefaultUseAutoGroup {
|
||||||
|
token.Group = "auto"
|
||||||
|
}
|
||||||
if err := token.Insert(); err != nil {
|
if err := token.Insert(); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -243,15 +246,15 @@ func Register(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUsers(c *gin.Context) {
|
func GetAllUsers(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo, err := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
if err != nil {
|
||||||
if p < 1 {
|
c.JSON(http.StatusOK, gin.H{
|
||||||
p = 1
|
"success": false,
|
||||||
|
"message": "parse page query failed",
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if pageSize < 0 {
|
users, total, err := model.GetAllUsers(pageInfo)
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -259,15 +262,13 @@ func GetAllUsers(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(users)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": gin.H{
|
"data": pageInfo,
|
||||||
"items": users,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -459,6 +460,9 @@ func GetSelf(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||||
|
user.Remark = ""
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@@ -483,7 +487,7 @@ func GetUserModels(c *gin.Context) {
|
|||||||
groups := setting.GetUserUsableGroups(user.Group)
|
groups := setting.GetUserUsableGroups(user.Group)
|
||||||
var models []string
|
var models []string
|
||||||
for group := range groups {
|
for group := range groups {
|
||||||
for _, g := range model.GetGroupModels(group) {
|
for _, g := range model.GetGroupEnabledModels(group) {
|
||||||
if !common.StringsContains(models, g) {
|
if !common.StringsContains(models, g) {
|
||||||
models = append(models, g)
|
models = append(models, g)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
version: '3.4' # 兼容旧版docker-compose
|
version: '3.4'
|
||||||
|
|
||||||
services:
|
services:
|
||||||
new-api:
|
new-api:
|
||||||
@@ -16,6 +16,7 @@ services:
|
|||||||
- REDIS_CONN_STRING=redis://redis
|
- REDIS_CONN_STRING=redis://redis
|
||||||
- TZ=Asia/Shanghai
|
- TZ=Asia/Shanghai
|
||||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||||
|
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
||||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||||
|
|||||||
@@ -178,7 +178,14 @@ type ClaudeRequest struct {
|
|||||||
|
|
||||||
type Thinking struct {
|
type Thinking struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
BudgetTokens int `json:"budget_tokens"`
|
BudgetTokens *int `json:"budget_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Thinking) GetBudgetTokens() int {
|
||||||
|
if c.BudgetTokens == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *c.BudgetTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeRequest) IsStringSystem() bool {
|
func (c *ClaudeRequest) IsStringSystem() bool {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ type ImageRequest struct {
|
|||||||
Background string `json:"background,omitempty"`
|
Background string `json:"background,omitempty"`
|
||||||
Moderation string `json:"moderation,omitempty"`
|
Moderation string `json:"moderation,omitempty"`
|
||||||
OutputFormat string `json:"output_format,omitempty"`
|
OutputFormat string `json:"output_format,omitempty"`
|
||||||
|
Watermark *bool `json:"watermark,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ImageResponse struct {
|
type ImageResponse struct {
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ type MidjourneyDto struct {
|
|||||||
StartTime int64 `json:"startTime"`
|
StartTime int64 `json:"startTime"`
|
||||||
FinishTime int64 `json:"finishTime"`
|
FinishTime int64 `json:"finishTime"`
|
||||||
ImageUrl string `json:"imageUrl"`
|
ImageUrl string `json:"imageUrl"`
|
||||||
|
VideoUrl string `json:"videoUrl"`
|
||||||
|
VideoUrls []ImgUrls `json:"videoUrls"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Progress string `json:"progress"`
|
Progress string `json:"progress"`
|
||||||
FailReason string `json:"failReason"`
|
FailReason string `json:"failReason"`
|
||||||
@@ -65,6 +67,10 @@ type MidjourneyDto struct {
|
|||||||
Properties *Properties `json:"properties"`
|
Properties *Properties `json:"properties"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ImgUrls struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
type MidjourneyStatus struct {
|
type MidjourneyStatus struct {
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,9 +53,11 @@ type GeneralOpenAIRequest struct {
|
|||||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||||
Audio json.RawMessage `json:"audio,omitempty"`
|
Audio json.RawMessage `json:"audio,omitempty"`
|
||||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||||
|
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||||
// OpenRouter Params
|
// OpenRouter Params
|
||||||
|
Usage json.RawMessage `json:"usage,omitempty"`
|
||||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||||
// Ali Qwen Params
|
// Ali Qwen Params
|
||||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||||
@@ -64,7 +66,7 @@ type GeneralOpenAIRequest struct {
|
|||||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||||
result := make(map[string]any)
|
result := make(map[string]any)
|
||||||
data, _ := common.EncodeJson(r)
|
data, _ := common.EncodeJson(r)
|
||||||
_ = common.DecodeJson(data, &result)
|
_ = common.UnmarshalJson(data, &result)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -644,4 +646,6 @@ type ResponsesToolsCall struct {
|
|||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||||
|
Function json.RawMessage `json:"function,omitempty"`
|
||||||
|
Container json.RawMessage `json:"container,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ type OpenAITextResponse struct {
|
|||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created int64 `json:"created"`
|
Created any `json:"created"`
|
||||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||||
Error *OpenAIError `json:"error,omitempty"`
|
Error *OpenAIError `json:"error,omitempty"`
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
@@ -178,6 +178,8 @@ type Usage struct {
|
|||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||||
|
// OpenRouter Params
|
||||||
|
Cost float64 `json:"cost,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type InputTokenDetails struct {
|
type InputTokenDetails struct {
|
||||||
|
|||||||
@@ -1,26 +1,11 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type OpenAIModelPermission struct {
|
import "one-api/constant"
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int `json:"created"`
|
|
||||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
|
||||||
AllowSampling bool `json:"allow_sampling"`
|
|
||||||
AllowLogprobs bool `json:"allow_logprobs"`
|
|
||||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
|
||||||
AllowView bool `json:"allow_view"`
|
|
||||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
|
||||||
Organization string `json:"organization"`
|
|
||||||
Group *string `json:"group"`
|
|
||||||
IsBlocking bool `json:"is_blocking"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIModels struct {
|
type OpenAIModels struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created int `json:"created"`
|
Created int `json:"created"`
|
||||||
OwnedBy string `json:"owned_by"`
|
OwnedBy string `json:"owned_by"`
|
||||||
Permission []OpenAIModelPermission `json:"permission"`
|
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||||
Root string `json:"root"`
|
|
||||||
Parent *string `json:"parent"`
|
|
||||||
}
|
}
|
||||||
|
|||||||
38
dto/ratio_sync.go
Normal file
38
dto/ratio_sync.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type UpstreamDTO struct {
|
||||||
|
ID int `json:"id,omitempty"`
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
BaseURL string `json:"base_url" binding:"required"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpstreamRequest struct {
|
||||||
|
ChannelIDs []int64 `json:"channel_ids"`
|
||||||
|
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||||
|
Timeout int `json:"timeout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResult 上游测试连通性结果
|
||||||
|
type TestResult struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DifferenceItem 差异项
|
||||||
|
// Current 为本地值,可能为 nil
|
||||||
|
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||||
|
|
||||||
|
type DifferenceItem struct {
|
||||||
|
Current interface{} `json:"current"`
|
||||||
|
Upstreams map[string]interface{} `json:"upstreams"`
|
||||||
|
Confidence map[string]bool `json:"confidence"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SyncableChannel struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@ type RerankRequest struct {
|
|||||||
Documents []any `json:"documents"`
|
Documents []any `json:"documents"`
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
TopN int `json:"top_n"`
|
TopN int `json:"top_n,omitempty"`
|
||||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||||
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
|
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
|
||||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||||
|
|||||||
47
dto/video.go
Normal file
47
dto/video.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type VideoRequest struct {
|
||||||
|
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
|
||||||
|
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
|
||||||
|
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
|
||||||
|
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
|
||||||
|
Width int `json:"width" example:"512"` // Video width
|
||||||
|
Height int `json:"height" example:"512"` // Video height
|
||||||
|
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
|
||||||
|
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
|
||||||
|
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
|
||||||
|
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
|
||||||
|
User string `json:"user,omitempty" example:"user-1234"` // User identifier
|
||||||
|
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoResponse 视频生成提交任务后的响应
|
||||||
|
type VideoResponse struct {
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoTaskResponse 查询视频生成任务状态的响应
|
||||||
|
type VideoTaskResponse struct {
|
||||||
|
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
|
||||||
|
Status string `json:"status" example:"succeeded"` // 任务状态
|
||||||
|
Url string `json:"url,omitempty"` // 视频资源URL(成功时)
|
||||||
|
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
|
||||||
|
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
|
||||||
|
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoTaskMetadata 视频任务元数据
|
||||||
|
type VideoTaskMetadata struct {
|
||||||
|
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
|
||||||
|
Fps int `json:"fps" example:"30"` // 实际帧率
|
||||||
|
Width int `json:"width" example:"512"` // 实际宽度
|
||||||
|
Height int `json:"height" example:"512"` // 实际高度
|
||||||
|
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoTaskError 视频任务错误信息
|
||||||
|
type VideoTaskError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
1041
i18n/zh-cn.json
Normal file
1041
i18n/zh-cn.json
Normal file
File diff suppressed because it is too large
Load Diff
91
main.go
91
main.go
@@ -12,7 +12,7 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/router"
|
"one-api/router"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/ratio_setting"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@@ -32,12 +32,12 @@ var buildFS embed.FS
|
|||||||
var indexPage []byte
|
var indexPage []byte
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
err := godotenv.Load(".env")
|
|
||||||
if err != nil {
|
|
||||||
common.SysLog("Support for .env file is disabled: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
common.LoadEnv()
|
err := InitResources()
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog("failed to initialize resources: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
common.SetupLogger()
|
common.SetupLogger()
|
||||||
common.SysLog("New API " + common.Version + " started")
|
common.SysLog("New API " + common.Version + " started")
|
||||||
@@ -47,19 +47,7 @@ func main() {
|
|||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
common.SysLog("running in debug mode")
|
common.SysLog("running in debug mode")
|
||||||
}
|
}
|
||||||
// Initialize SQL Database
|
|
||||||
err = model.InitDB()
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog("failed to initialize database: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
model.CheckSetup()
|
|
||||||
|
|
||||||
// Initialize SQL Database
|
|
||||||
err = model.InitLogDB()
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog("failed to initialize database: " + err.Error())
|
|
||||||
}
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := model.CloseDB()
|
err := model.CloseDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -67,21 +55,6 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Initialize Redis
|
|
||||||
err = common.InitRedisClient()
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog("failed to initialize Redis: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize model settings
|
|
||||||
operation_setting.InitRatioSettings()
|
|
||||||
// Initialize constants
|
|
||||||
constant.InitEnv()
|
|
||||||
// Initialize options
|
|
||||||
model.InitOptionMap()
|
|
||||||
|
|
||||||
service.InitTokenEncoders()
|
|
||||||
|
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
// for compatibility with old versions
|
// for compatibility with old versions
|
||||||
common.MemoryCacheEnabled = true
|
common.MemoryCacheEnabled = true
|
||||||
@@ -105,10 +78,12 @@ func main() {
|
|||||||
model.InitChannelCache()
|
model.InitChannelCache()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go model.SyncOptions(common.SyncFrequency)
|
|
||||||
go model.SyncChannelCache(common.SyncFrequency)
|
go model.SyncChannelCache(common.SyncFrequency)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 热更新配置
|
||||||
|
go model.SyncOptions(common.SyncFrequency)
|
||||||
|
|
||||||
// 数据看板
|
// 数据看板
|
||||||
go model.UpdateQuotaData()
|
go model.UpdateQuotaData()
|
||||||
|
|
||||||
@@ -184,3 +159,51 @@ func main() {
|
|||||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InitResources() error {
|
||||||
|
// Initialize resources here if needed
|
||||||
|
// This is a placeholder function for future resource initialization
|
||||||
|
err := godotenv.Load(".env")
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
|
||||||
|
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载环境变量
|
||||||
|
common.InitEnv()
|
||||||
|
|
||||||
|
// Initialize model settings
|
||||||
|
ratio_setting.InitRatioSettings()
|
||||||
|
|
||||||
|
service.InitHttpClient()
|
||||||
|
|
||||||
|
service.InitTokenEncoders()
|
||||||
|
|
||||||
|
// Initialize SQL Database
|
||||||
|
err = model.InitDB()
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog("failed to initialize database: " + err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
model.CheckSetup()
|
||||||
|
|
||||||
|
// Initialize options, should after model.InitDB()
|
||||||
|
model.InitOptionMap()
|
||||||
|
|
||||||
|
// 初始化模型
|
||||||
|
model.GetPricing()
|
||||||
|
|
||||||
|
// Initialize SQL Database
|
||||||
|
err = model.InitLogDB()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize Redis
|
||||||
|
err = common.InitRedisClient()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// gemini api 从query中获取key
|
// gemini api 从query中获取key
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||||
skKey := c.Query("key")
|
skKey := c.Query("key")
|
||||||
if skKey != "" {
|
if skKey != "" {
|
||||||
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -24,7 +25,7 @@ type ModelRequest struct {
|
|||||||
|
|
||||||
func Distribute() func(c *gin.Context) {
|
func Distribute() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
allowIpsMap := c.GetStringMap("allow_ips")
|
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
||||||
if len(allowIpsMap) != 0 {
|
if len(allowIpsMap) != 0 {
|
||||||
clientIp := c.ClientIP()
|
clientIp := c.ClientIP()
|
||||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||||
@@ -33,14 +34,14 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
channelId, ok := c.Get("specific_channel_id")
|
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userGroup := c.GetString(constant.ContextKeyUserGroup)
|
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||||
tokenGroup := c.GetString("token_group")
|
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||||
if tokenGroup != "" {
|
if tokenGroup != "" {
|
||||||
// check common.UserUsableGroups[userGroup]
|
// check common.UserUsableGroups[userGroup]
|
||||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||||
@@ -48,13 +49,15 @@ func Distribute() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// check group in common.GroupRatio
|
// check group in common.GroupRatio
|
||||||
if !setting.ContainsGroupRatio(tokenGroup) {
|
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||||
|
if tokenGroup != "auto" {
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
userGroup = tokenGroup
|
userGroup = tokenGroup
|
||||||
}
|
}
|
||||||
c.Set("group", userGroup)
|
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -73,9 +76,9 @@ func Distribute() func(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
// Select a channel for the user
|
// Select a channel for the user
|
||||||
// check token model mapping
|
// check token model mapping
|
||||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
if modelLimitEnable {
|
if modelLimitEnable {
|
||||||
s, ok := c.Get("token_model_limit")
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||||
var tokenModelLimit map[string]bool
|
var tokenModelLimit map[string]bool
|
||||||
if ok {
|
if ok {
|
||||||
tokenModelLimit = s.(map[string]bool)
|
tokenModelLimit = s.(map[string]bool)
|
||||||
@@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if shouldSelectChannel {
|
if shouldSelectChannel {
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
var selectGroup string
|
||||||
|
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
showGroup := userGroup
|
||||||
|
if userGroup == "auto" {
|
||||||
|
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||||
|
}
|
||||||
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
|
||||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
@@ -113,7 +121,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
@@ -162,7 +170,26 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
}
|
}
|
||||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||||
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
|
var platform string
|
||||||
|
var relayMode int
|
||||||
|
if strings.HasPrefix(modelRequest.Model, "jimeng") {
|
||||||
|
platform = string(constant.TaskPlatformJimeng)
|
||||||
|
relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
|
||||||
|
if relayMode == relayconstant.RelayModeJimengFetchByID {
|
||||||
|
shouldSelectChannel = false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
platform = string(constant.TaskPlatformKling)
|
||||||
|
relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||||
|
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||||
|
shouldSelectChannel = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Set("platform", platform)
|
||||||
|
c.Set("relay_mode", relayMode)
|
||||||
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||||
relayMode := relayconstant.RelayModeGemini
|
relayMode := relayconstant.RelayModeGemini
|
||||||
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
|
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
|
||||||
@@ -234,21 +261,21 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
// TODO: api_version统一
|
// TODO: api_version统一
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeVertexAi:
|
case constant.ChannelTypeVertexAi:
|
||||||
c.Set("region", channel.Other)
|
c.Set("region", channel.Other)
|
||||||
case common.ChannelTypeXunfei:
|
case constant.ChannelTypeXunfei:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeGemini:
|
case constant.ChannelTypeGemini:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeAli:
|
case constant.ChannelTypeAli:
|
||||||
c.Set("plugin", channel.Other)
|
c.Set("plugin", channel.Other)
|
||||||
case common.ChannelCloudflare:
|
case constant.ChannelCloudflare:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeMokaAI:
|
case constant.ChannelTypeMokaAI:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeCoze:
|
case constant.ChannelTypeCoze:
|
||||||
c.Set("bot_id", channel.Other)
|
c.Set("bot_id", channel.Other)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
47
middleware/kling_adapter.go
Normal file
47
middleware/kling_adapter.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func KlingRequestConvert() func(c *gin.Context) {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var originalReq map[string]interface{}
|
||||||
|
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model, _ := originalReq["model"].(string)
|
||||||
|
prompt, _ := originalReq["prompt"].(string)
|
||||||
|
|
||||||
|
unifiedReq := map[string]interface{}{
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"metadata": originalReq,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(unifiedReq)
|
||||||
|
if err != nil {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rewrite request body and path
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||||
|
c.Request.URL.Path = "/v1/video/generations"
|
||||||
|
if image := originalReq["image"]; image == "" {
|
||||||
|
c.Set("action", constant.TaskActionTextGenerate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We have to reset the request body for the next handlers
|
||||||
|
c.Set(common.KeyRequestBody, jsonData)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
|
|||||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||||
|
|
||||||
// 获取分组
|
// 获取分组
|
||||||
group := c.GetString("token_group")
|
group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||||
if group == "" {
|
if group == "" {
|
||||||
group = c.GetString(constant.ContextKeyUserGroup)
|
group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
//获取分组的限流配置
|
//获取分组的限流配置
|
||||||
|
|||||||
@@ -21,7 +21,22 @@ type Ability struct {
|
|||||||
Tag *string `json:"tag" gorm:"index"`
|
Tag *string `json:"tag" gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGroupModels(group string) []string {
|
type AbilityWithChannel struct {
|
||||||
|
Ability
|
||||||
|
ChannelType int `json:"channel_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
|
||||||
|
var abilities []AbilityWithChannel
|
||||||
|
err := DB.Table("abilities").
|
||||||
|
Select("abilities.*, channels.type as channel_type").
|
||||||
|
Joins("left join channels on abilities.channel_id = channels.id").
|
||||||
|
Where("abilities.enabled = ?", true).
|
||||||
|
Scan(&abilities).Error
|
||||||
|
return abilities, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetGroupEnabledModels(group string) []string {
|
||||||
var models []string
|
var models []string
|
||||||
// Find distinct models
|
// Find distinct models
|
||||||
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||||
@@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) {
|
|||||||
var priorities []int
|
var priorities []int
|
||||||
err := DB.Model(&Ability{}).
|
err := DB.Model(&Ability{}).
|
||||||
Select("DISTINCT(priority)").
|
Select("DISTINCT(priority)").
|
||||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
|
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
|
||||||
Order("priority DESC"). // 按优先级降序排序
|
Order("priority DESC"). // 按优先级降序排序
|
||||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||||
|
|
||||||
@@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
|
||||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
|
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
|
||||||
if retry != 0 {
|
if retry != 0 {
|
||||||
priority, err := getPriority(group, model, retry)
|
priority, err := getPriority(group, model, retry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||||
} else {
|
} else {
|
||||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
|
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/setting"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var group2model2channels map[string]map[string][]*Channel
|
var group2model2channels map[string]map[string][]*Channel
|
||||||
@@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
|
||||||
|
var channel *Channel
|
||||||
|
var err error
|
||||||
|
selectGroup := group
|
||||||
|
if group == "auto" {
|
||||||
|
if len(setting.AutoGroups) == 0 {
|
||||||
|
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||||||
|
}
|
||||||
|
for _, autoGroup := range setting.AutoGroups {
|
||||||
|
if common.DebugEnabled {
|
||||||
|
println("autoGroup:", autoGroup)
|
||||||
|
}
|
||||||
|
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
|
||||||
|
if channel == nil {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
c.Set("auto_group", autoGroup)
|
||||||
|
selectGroup = autoGroup
|
||||||
|
if common.DebugEnabled {
|
||||||
|
println("selectGroup:", selectGroup)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
channel, err = getRandomSatisfiedChannel(group, model, retry)
|
||||||
|
if err != nil {
|
||||||
|
return nil, group, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if channel == nil {
|
||||||
|
return nil, group, errors.New("channel not found")
|
||||||
|
}
|
||||||
|
return channel, selectGroup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||||
model = "gpt-4-gizmo-*"
|
model = "gpt-4-gizmo-*"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -617,3 +617,39 @@ func CountAllTags() (int64, error) {
|
|||||||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||||||
return total, err
|
return total, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get channels of specified type with pagination
|
||||||
|
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
|
||||||
|
var channels []*Channel
|
||||||
|
order := "priority desc"
|
||||||
|
if idSort {
|
||||||
|
order = "id desc"
|
||||||
|
}
|
||||||
|
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||||
|
return channels, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count channels of specific type
|
||||||
|
func CountChannelsByType(channelType int) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return map[type]count for all channels
|
||||||
|
func CountChannelsGroupByType() (map[int64]int64, error) {
|
||||||
|
type result struct {
|
||||||
|
Type int64 `gorm:"column:type"`
|
||||||
|
Count int64 `gorm:"column:count"`
|
||||||
|
}
|
||||||
|
var results []result
|
||||||
|
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts := make(map[int64]int64)
|
||||||
|
for _, r := range results {
|
||||||
|
counts[r.Type] = r.Count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,6 +46,15 @@ func initCol() {
|
|||||||
logGroupCol = commonGroupCol
|
logGroupCol = commonGroupCol
|
||||||
logKeyCol = commonKeyCol
|
logKeyCol = commonKeyCol
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
logGroupCol = `"group"`
|
||||||
|
logKeyCol = `"key"`
|
||||||
|
} else {
|
||||||
|
logGroupCol = commonGroupCol
|
||||||
|
logKeyCol = commonKeyCol
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// log sql type and database type
|
// log sql type and database type
|
||||||
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ type Midjourney struct {
|
|||||||
StartTime int64 `json:"start_time" gorm:"index"`
|
StartTime int64 `json:"start_time" gorm:"index"`
|
||||||
FinishTime int64 `json:"finish_time" gorm:"index"`
|
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||||
ImageUrl string `json:"image_url"`
|
ImageUrl string `json:"image_url"`
|
||||||
|
VideoUrl string `json:"video_url"`
|
||||||
|
VideoUrls string `json:"video_urls"`
|
||||||
Status string `json:"status" gorm:"type:varchar(20);index"`
|
Status string `json:"status" gorm:"type:varchar(20);index"`
|
||||||
Progress string `json:"progress" gorm:"type:varchar(30);index"`
|
Progress string `json:"progress" gorm:"type:varchar(30);index"`
|
||||||
FailReason string `json:"fail_reason"`
|
FailReason string `json:"fail_reason"`
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/config"
|
"one-api/setting/config"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/operation_setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -76,6 +77,9 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||||
|
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||||
|
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||||
|
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
||||||
common.OptionMap["GitHubClientId"] = ""
|
common.OptionMap["GitHubClientId"] = ""
|
||||||
common.OptionMap["GitHubClientSecret"] = ""
|
common.OptionMap["GitHubClientSecret"] = ""
|
||||||
common.OptionMap["TelegramBotToken"] = ""
|
common.OptionMap["TelegramBotToken"] = ""
|
||||||
@@ -94,13 +98,13 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
||||||
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
||||||
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
|
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
|
||||||
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
|
||||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
|
||||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
|
||||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
|
||||||
common.OptionMap["GroupGroupRatio"] = setting.GroupGroupRatio2JSONString()
|
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
|
||||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||||
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
|
||||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||||
//common.OptionMap["ChatLink"] = common.ChatLink
|
//common.OptionMap["ChatLink"] = common.ChatLink
|
||||||
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||||
@@ -123,6 +127,7 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||||
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
||||||
|
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
|
||||||
|
|
||||||
// 自动添加所有注册的模型配置
|
// 自动添加所有注册的模型配置
|
||||||
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
||||||
@@ -192,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
common.ImageDownloadPermission = intValue
|
common.ImageDownloadPermission = intValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
|
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
|
||||||
boolValue := value == "true"
|
boolValue := value == "true"
|
||||||
switch key {
|
switch key {
|
||||||
case "PasswordRegisterEnabled":
|
case "PasswordRegisterEnabled":
|
||||||
@@ -261,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
common.SMTPSSLEnabled = boolValue
|
common.SMTPSSLEnabled = boolValue
|
||||||
case "WorkerAllowHttpImageRequestEnabled":
|
case "WorkerAllowHttpImageRequestEnabled":
|
||||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||||
|
case "DefaultUseAutoGroup":
|
||||||
|
setting.DefaultUseAutoGroup = boolValue
|
||||||
|
case "ExposeRatioEnabled":
|
||||||
|
ratio_setting.SetExposeRatioEnabled(boolValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch key {
|
switch key {
|
||||||
@@ -287,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
setting.PayAddress = value
|
setting.PayAddress = value
|
||||||
case "Chats":
|
case "Chats":
|
||||||
err = setting.UpdateChatsByJsonString(value)
|
err = setting.UpdateChatsByJsonString(value)
|
||||||
|
case "AutoGroups":
|
||||||
|
err = setting.UpdateAutoGroupsByJsonString(value)
|
||||||
case "CustomCallbackAddress":
|
case "CustomCallbackAddress":
|
||||||
setting.CustomCallbackAddress = value
|
setting.CustomCallbackAddress = value
|
||||||
case "EpayId":
|
case "EpayId":
|
||||||
@@ -352,19 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
case "DataExportDefaultTime":
|
case "DataExportDefaultTime":
|
||||||
common.DataExportDefaultTime = value
|
common.DataExportDefaultTime = value
|
||||||
case "ModelRatio":
|
case "ModelRatio":
|
||||||
err = operation_setting.UpdateModelRatioByJSONString(value)
|
err = ratio_setting.UpdateModelRatioByJSONString(value)
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
err = setting.UpdateGroupRatioByJSONString(value)
|
err = ratio_setting.UpdateGroupRatioByJSONString(value)
|
||||||
case "GroupGroupRatio":
|
case "GroupGroupRatio":
|
||||||
err = setting.UpdateGroupGroupRatioByJSONString(value)
|
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
|
||||||
case "UserUsableGroups":
|
case "UserUsableGroups":
|
||||||
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||||
case "CompletionRatio":
|
case "CompletionRatio":
|
||||||
err = operation_setting.UpdateCompletionRatioByJSONString(value)
|
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
|
||||||
case "ModelPrice":
|
case "ModelPrice":
|
||||||
err = operation_setting.UpdateModelPriceByJSONString(value)
|
err = ratio_setting.UpdateModelPriceByJSONString(value)
|
||||||
case "CacheRatio":
|
case "CacheRatio":
|
||||||
err = operation_setting.UpdateCacheRatioByJSONString(value)
|
err = ratio_setting.UpdateCacheRatioByJSONString(value)
|
||||||
case "TopUpLink":
|
case "TopUpLink":
|
||||||
common.TopUpLink = value
|
common.TopUpLink = value
|
||||||
//case "ChatLink":
|
//case "ChatLink":
|
||||||
@@ -381,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||||
case "StreamCacheQueueLength":
|
case "StreamCacheQueueLength":
|
||||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||||
|
case "PayMethods":
|
||||||
|
err = setting.UpdatePayMethodsByJsonString(value)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
100
model/pricing.go
100
model/pricing.go
@@ -1,8 +1,11 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/constant"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
|
"one-api/types"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -14,7 +17,8 @@ type Pricing struct {
|
|||||||
ModelPrice float64 `json:"model_price"`
|
ModelPrice float64 `json:"model_price"`
|
||||||
OwnerBy string `json:"owner_by"`
|
OwnerBy string `json:"owner_by"`
|
||||||
CompletionRatio float64 `json:"completion_ratio"`
|
CompletionRatio float64 `json:"completion_ratio"`
|
||||||
EnableGroup []string `json:"enable_groups,omitempty"`
|
EnableGroup []string `json:"enable_groups"`
|
||||||
|
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -23,56 +27,98 @@ var (
|
|||||||
updatePricingLock sync.Mutex
|
updatePricingLock sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||||
|
modelSupportEndpointsLock = sync.RWMutex{}
|
||||||
|
)
|
||||||
|
|
||||||
func GetPricing() []Pricing {
|
func GetPricing() []Pricing {
|
||||||
|
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||||
updatePricingLock.Lock()
|
updatePricingLock.Lock()
|
||||||
defer updatePricingLock.Unlock()
|
defer updatePricingLock.Unlock()
|
||||||
|
// Double check after acquiring the lock
|
||||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||||
|
modelSupportEndpointsLock.Lock()
|
||||||
|
defer modelSupportEndpointsLock.Unlock()
|
||||||
updatePricing()
|
updatePricing()
|
||||||
}
|
}
|
||||||
//if group != "" {
|
}
|
||||||
// userPricingMap := make([]Pricing, 0)
|
|
||||||
// models := GetGroupModels(group)
|
|
||||||
// for _, pricing := range pricingMap {
|
|
||||||
// if !common.StringsContains(models, pricing.ModelName) {
|
|
||||||
// pricing.Available = false
|
|
||||||
// }
|
|
||||||
// userPricingMap = append(userPricingMap, pricing)
|
|
||||||
// }
|
|
||||||
// return userPricingMap
|
|
||||||
//}
|
|
||||||
return pricingMap
|
return pricingMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||||||
|
if model == "" {
|
||||||
|
return make([]constant.EndpointType, 0)
|
||||||
|
}
|
||||||
|
modelSupportEndpointsLock.RLock()
|
||||||
|
defer modelSupportEndpointsLock.RUnlock()
|
||||||
|
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
|
||||||
|
return endpoints
|
||||||
|
}
|
||||||
|
return make([]constant.EndpointType, 0)
|
||||||
|
}
|
||||||
|
|
||||||
func updatePricing() {
|
func updatePricing() {
|
||||||
//modelRatios := common.GetModelRatios()
|
//modelRatios := common.GetModelRatios()
|
||||||
enableAbilities := GetAllEnableAbilities()
|
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||||||
modelGroupsMap := make(map[string][]string)
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelGroupsMap := make(map[string]*types.Set[string])
|
||||||
|
|
||||||
for _, ability := range enableAbilities {
|
for _, ability := range enableAbilities {
|
||||||
groups := modelGroupsMap[ability.Model]
|
groups, ok := modelGroupsMap[ability.Model]
|
||||||
if groups == nil {
|
if !ok {
|
||||||
groups = make([]string, 0)
|
groups = types.NewSet[string]()
|
||||||
}
|
|
||||||
if !common.StringsContains(groups, ability.Group) {
|
|
||||||
groups = append(groups, ability.Group)
|
|
||||||
}
|
|
||||||
modelGroupsMap[ability.Model] = groups
|
modelGroupsMap[ability.Model] = groups
|
||||||
}
|
}
|
||||||
|
groups.Add(ability.Group)
|
||||||
|
}
|
||||||
|
|
||||||
|
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||||||
|
modelSupportEndpointsStr := make(map[string][]string)
|
||||||
|
|
||||||
|
for _, ability := range enableAbilities {
|
||||||
|
endpoints, ok := modelSupportEndpointsStr[ability.Model]
|
||||||
|
if !ok {
|
||||||
|
endpoints = make([]string, 0)
|
||||||
|
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||||
|
}
|
||||||
|
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||||||
|
for _, channelType := range channelTypes {
|
||||||
|
if !common.StringsContains(endpoints, string(channelType)) {
|
||||||
|
endpoints = append(endpoints, string(channelType))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||||
|
for model, endpoints := range modelSupportEndpointsStr {
|
||||||
|
supportedEndpoints := make([]constant.EndpointType, 0)
|
||||||
|
for _, endpointStr := range endpoints {
|
||||||
|
endpointType := constant.EndpointType(endpointStr)
|
||||||
|
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||||||
|
}
|
||||||
|
modelSupportEndpointTypes[model] = supportedEndpoints
|
||||||
|
}
|
||||||
|
|
||||||
pricingMap = make([]Pricing, 0)
|
pricingMap = make([]Pricing, 0)
|
||||||
for model, groups := range modelGroupsMap {
|
for model, groups := range modelGroupsMap {
|
||||||
pricing := Pricing{
|
pricing := Pricing{
|
||||||
ModelName: model,
|
ModelName: model,
|
||||||
EnableGroup: groups,
|
EnableGroup: groups.Items(),
|
||||||
|
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||||||
}
|
}
|
||||||
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
|
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||||
if findPrice {
|
if findPrice {
|
||||||
pricing.ModelPrice = modelPrice
|
pricing.ModelPrice = modelPrice
|
||||||
pricing.QuotaType = 1
|
pricing.QuotaType = 1
|
||||||
} else {
|
} else {
|
||||||
modelRatio, _ := operation_setting.GetModelRatio(model)
|
modelRatio, _ := ratio_setting.GetModelRatio(model)
|
||||||
pricing.ModelRatio = modelRatio
|
pricing.ModelRatio = modelRatio
|
||||||
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
|
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||||||
pricing.QuotaType = 0
|
pricing.QuotaType = 0
|
||||||
}
|
}
|
||||||
pricingMap = append(pricingMap, pricing)
|
pricingMap = append(pricingMap, pricing)
|
||||||
|
|||||||
@@ -327,3 +327,37 @@ func CountUserTokens(userId int) (int64, error) {
|
|||||||
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
|
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
|
||||||
return total, err
|
return total, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
|
||||||
|
func BatchDeleteTokens(ids []int, userId int) (int, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return 0, errors.New("ids 不能为空!")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := DB.Begin()
|
||||||
|
|
||||||
|
var tokens []Token
|
||||||
|
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.RedisEnabled {
|
||||||
|
gopool.Go(func() {
|
||||||
|
for _, t := range tokens {
|
||||||
|
_ = cacheDeleteToken(t.Key)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(tokens), nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
func cacheSetToken(token Token) error {
|
func cacheSetToken(token Token) error {
|
||||||
key := common.GenerateHMAC(token.Key)
|
key := common.GenerateHMAC(token.Key)
|
||||||
token.Clean()
|
token.Clean()
|
||||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
|
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ type User struct {
|
|||||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||||
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
|
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
|
||||||
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
||||||
|
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) ToBaseUser() *UserBase {
|
func (user *User) ToBaseUser() *UserBase {
|
||||||
@@ -113,7 +114,7 @@ func GetMaxUserId() int {
|
|||||||
return user.Id
|
return user.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) {
|
func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
|
||||||
// Start transaction
|
// Start transaction
|
||||||
tx := DB.Begin()
|
tx := DB.Begin()
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
@@ -133,7 +134,7 @@ func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get paginated users within same transaction
|
// Get paginated users within same transaction
|
||||||
err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
|
err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@@ -366,6 +367,7 @@ func (user *User) Edit(updatePassword bool) error {
|
|||||||
"display_name": newUser.DisplayName,
|
"display_name": newUser.DisplayName,
|
||||||
"group": newUser.Group,
|
"group": newUser.Group,
|
||||||
"quota": newUser.Quota,
|
"quota": newUser.Quota,
|
||||||
|
"remark": newUser.Remark,
|
||||||
}
|
}
|
||||||
if updatePassword {
|
if updatePassword {
|
||||||
updates["password"] = newUser.Password
|
updates["password"] = newUser.Password
|
||||||
|
|||||||
@@ -24,12 +24,12 @@ type UserBase struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (user *UserBase) WriteContext(c *gin.Context) {
|
func (user *UserBase) WriteContext(c *gin.Context) {
|
||||||
c.Set(constant.ContextKeyUserGroup, user.Group)
|
common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
|
||||||
c.Set(constant.ContextKeyUserQuota, user.Quota)
|
common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
|
||||||
c.Set(constant.ContextKeyUserStatus, user.Status)
|
common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
|
||||||
c.Set(constant.ContextKeyUserEmail, user.Email)
|
common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
|
||||||
c.Set("username", user.Username)
|
common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
|
||||||
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
|
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *UserBase) GetSetting() map[string]interface{} {
|
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||||
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
|
|||||||
return common.RedisHSetObj(
|
return common.RedisHSetObj(
|
||||||
getUserCacheKey(user.Id),
|
getUserCacheKey(user.Id),
|
||||||
user.ToBaseUser(),
|
user.ToBaseUser(),
|
||||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
time.Duration(common.RedisKeyCacheSeconds())*time.Second,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,12 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func batchUpdate() {
|
func batchUpdate() {
|
||||||
|
// check if there's any data to update
|
||||||
|
hasData := false
|
||||||
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
|
batchUpdateLocks[i].Lock()
|
||||||
|
if len(batchUpdateStores[i]) > 0 {
|
||||||
|
hasData = true
|
||||||
|
batchUpdateLocks[i].Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
batchUpdateLocks[i].Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasData {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
common.SysLog("batch update started")
|
common.SysLog("batch update started")
|
||||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
batchUpdateLocks[i].Lock()
|
batchUpdateLocks[i].Lock()
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
promptTokens := 0
|
promptTokens := 0
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||||
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
preConsumedTokens = promptTokens
|
preConsumedTokens = promptTokens
|
||||||
relayInfo.PromptTokens = promptTokens
|
relayInfo.PromptTokens = promptTokens
|
||||||
}
|
}
|
||||||
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo)
|
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
audioRequest.Model = relayInfo.UpstreamModelName
|
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||||
@@ -44,4 +44,6 @@ type TaskAdaptor interface {
|
|||||||
|
|
||||||
// FetchTask
|
// FetchTask
|
||||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||||
|
|
||||||
|
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
var fullRequestURL string
|
var fullRequestURL string
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
@@ -82,7 +82,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
return embeddingRequestOpenAI2Ali(request), nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
|
|||||||
@@ -132,10 +132,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -35,10 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var aliResponse AliRerankResponse
|
var aliResponse AliRerankResponse
|
||||||
err = json.Unmarshal(responseBody, &aliResponse)
|
err = json.Unmarshal(responseBody, &aliResponse)
|
||||||
|
|||||||
@@ -39,34 +39,18 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
|
|||||||
}
|
}
|
||||||
|
|
||||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
var aliResponse AliEmbeddingResponse
|
var fullTextResponse dto.OpenAIEmbeddingResponse
|
||||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if aliResponse.Code != "" {
|
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
|
||||||
Error: dto.OpenAIError{
|
|
||||||
Message: aliResponse.Message,
|
|
||||||
Type: aliResponse.Code,
|
|
||||||
Param: aliResponse.RequestId,
|
|
||||||
Code: aliResponse.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
model := c.GetString("model")
|
model := c.GetString("model")
|
||||||
if model == "" {
|
if model == "" {
|
||||||
model = "text-embedding-v4"
|
model = "text-embedding-v4"
|
||||||
}
|
}
|
||||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -186,10 +170,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
err := resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,10 +180,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &aliResponse)
|
err = json.Unmarshal(responseBody, &aliResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
|||||||
@@ -166,10 +166,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
err := resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,10 +176,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -215,10 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -280,7 +271,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
|||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Add("Accept", "application/json")
|
req.Header.Add("Accept", "application/json")
|
||||||
res, err := service.GetImpatientHttpClient().Do(req)
|
res, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel/openrouter"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -113,7 +114,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
|||||||
// BudgetTokens 为 max_tokens 的 80%
|
// BudgetTokens 为 max_tokens 的 80%
|
||||||
claudeRequest.Thinking = &dto.Thinking{
|
claudeRequest.Thinking = &dto.Thinking{
|
||||||
Type: "enabled",
|
Type: "enabled",
|
||||||
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||||
}
|
}
|
||||||
// TODO: 临时处理
|
// TODO: 临时处理
|
||||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||||
@@ -122,6 +123,21 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
|||||||
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if textRequest.Reasoning != nil {
|
||||||
|
var reasoning openrouter.RequestReasoning
|
||||||
|
if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
budgetTokens := reasoning.MaxTokens
|
||||||
|
if budgetTokens > 0 {
|
||||||
|
claudeRequest.Thinking = &dto.Thinking{
|
||||||
|
Type: "enabled",
|
||||||
|
BudgetTokens: &budgetTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if textRequest.Stop != nil {
|
if textRequest.Stop != nil {
|
||||||
// stop maybe string/array string, convert to array string
|
// stop maybe string/array string, convert to array string
|
||||||
switch textRequest.Stop.(type) {
|
switch textRequest.Stop.(type) {
|
||||||
@@ -454,6 +470,7 @@ type ClaudeResponseInfo struct {
|
|||||||
Model string
|
Model string
|
||||||
ResponseText strings.Builder
|
ResponseText strings.Builder
|
||||||
Usage *dto.Usage
|
Usage *dto.Usage
|
||||||
|
Done bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||||
@@ -461,20 +478,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
|||||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||||
} else {
|
} else {
|
||||||
if claudeResponse.Type == "message_start" {
|
if claudeResponse.Type == "message_start" {
|
||||||
// message_start, 获取usage
|
|
||||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||||
claudeInfo.Model = claudeResponse.Message.Model
|
claudeInfo.Model = claudeResponse.Message.Model
|
||||||
|
|
||||||
|
// message_start, 获取usage
|
||||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||||
|
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||||
|
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||||
|
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||||
} else if claudeResponse.Type == "content_block_delta" {
|
} else if claudeResponse.Type == "content_block_delta" {
|
||||||
if claudeResponse.Delta.Text != nil {
|
if claudeResponse.Delta.Text != nil {
|
||||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||||
}
|
}
|
||||||
|
if claudeResponse.Delta.Thinking != "" {
|
||||||
|
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
|
||||||
|
}
|
||||||
} else if claudeResponse.Type == "message_delta" {
|
} else if claudeResponse.Type == "message_delta" {
|
||||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
// 最终的usage获取
|
||||||
if claudeResponse.Usage.InputTokens > 0 {
|
if claudeResponse.Usage.InputTokens > 0 {
|
||||||
|
// 不叠加,只取最新的
|
||||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||||
}
|
}
|
||||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
|
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||||
|
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||||
|
|
||||||
|
// 判断是否完整
|
||||||
|
claudeInfo.Done = true
|
||||||
} else if claudeResponse.Type == "content_block_start" {
|
} else if claudeResponse.Type == "content_block_start" {
|
||||||
} else {
|
} else {
|
||||||
return false
|
return false
|
||||||
@@ -490,7 +519,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
|||||||
|
|
||||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
var claudeResponse dto.ClaudeResponse
|
var claudeResponse dto.ClaudeResponse
|
||||||
err := common.DecodeJsonStr(data, &claudeResponse)
|
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
|
||||||
@@ -506,25 +535,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||||
|
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||||
|
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
|
||||||
} else {
|
} else {
|
||||||
if claudeResponse.Type == "message_start" {
|
if claudeResponse.Type == "message_start" {
|
||||||
// message_start, 获取usage
|
// message_start, 获取usage
|
||||||
info.UpstreamModelName = claudeResponse.Message.Model
|
info.UpstreamModelName = claudeResponse.Message.Model
|
||||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
|
||||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
|
||||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
|
||||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
|
||||||
} else if claudeResponse.Type == "content_block_delta" {
|
} else if claudeResponse.Type == "content_block_delta" {
|
||||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
|
|
||||||
} else if claudeResponse.Type == "message_delta" {
|
} else if claudeResponse.Type == "message_delta" {
|
||||||
if claudeResponse.Usage.InputTokens > 0 {
|
|
||||||
// 不叠加,只取最新的
|
|
||||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
|
||||||
}
|
|
||||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
|
||||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||||
@@ -544,29 +563,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
}
|
}
|
||||||
|
|
||||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
|
||||||
// 说明流模式建立失败,可能为官方出错
|
|
||||||
if claudeInfo.Usage.PromptTokens == 0 {
|
|
||||||
//usage.PromptTokens = info.PromptTokens
|
|
||||||
}
|
|
||||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
|
||||||
if requestMode == RequestModeCompletion {
|
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
|
||||||
} else {
|
} else {
|
||||||
if claudeInfo.Usage.PromptTokens == 0 {
|
if claudeInfo.Usage.PromptTokens == 0 {
|
||||||
//上游出错
|
//上游出错
|
||||||
}
|
}
|
||||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
if common.DebugEnabled {
|
||||||
|
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||||
|
}
|
||||||
|
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||||
|
//
|
||||||
|
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||||
|
|
||||||
if info.ShouldIncludeUsage {
|
if info.ShouldIncludeUsage {
|
||||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||||
err := helper.ObjectData(c, response)
|
err := helper.ObjectData(c, response)
|
||||||
@@ -604,7 +619,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
|
|
||||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
var claudeResponse dto.ClaudeResponse
|
var claudeResponse dto.ClaudeResponse
|
||||||
err := common.DecodeJson(data, &claudeResponse)
|
err := common.UnmarshalJson(data, &claudeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -619,10 +634,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||||
@@ -645,13 +657,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
case relaycommon.RelayFormatClaude:
|
case relaycommon.RelayFormatClaude:
|
||||||
responseData = data
|
responseData = data
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(http.StatusOK)
|
common.IOCopyBytesGracefully(c, nil, responseData)
|
||||||
_, err = c.Writer.Write(responseData)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
claudeInfo := &ClaudeResponseInfo{
|
claudeInfo := &ClaudeResponseInfo{
|
||||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
@@ -663,7 +676,6 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println("responseBody: ", string(responseBody))
|
println("responseBody: ", string(responseBody))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||||
}
|
}
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
if info.ShouldIncludeUsage {
|
if info.ShouldIncludeUsage {
|
||||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||||
err := helper.ObjectData(c, response)
|
err := helper.ObjectData(c, response)
|
||||||
@@ -81,10 +81,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
}
|
}
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
err := resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, "close_response_body_failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
@@ -94,10 +91,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var response dto.TextResponse
|
var response dto.TextResponse
|
||||||
err = json.Unmarshal(responseBody, &response)
|
err = json.Unmarshal(responseBody, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -108,7 +102,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
|||||||
for _, choice := range response.Choices {
|
for _, choice := range response.Choices {
|
||||||
responseText += choice.Message.StringContent()
|
responseText += choice.Message.StringContent()
|
||||||
}
|
}
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
response.Usage = *usage
|
response.Usage = *usage
|
||||||
response.Id = helper.GetResponseID(c)
|
response.Id = helper.GetResponseID(c)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
@@ -127,10 +121,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &cfResp)
|
err = json.Unmarshal(responseBody, &cfResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -150,7 +141,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
|||||||
|
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage.PromptTokens = info.PromptTokens
|
||||||
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
|
|
||||||
return nil, usage
|
return nil, usage
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package cohere
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -78,7 +77,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
responseId := helper.GetResponseID(c)
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
responseText := ""
|
responseText := ""
|
||||||
@@ -163,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
if usage.PromptTokens == 0 {
|
if usage.PromptTokens == 0 {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
}
|
}
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
@@ -174,10 +173,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var cohereResp CohereResponseResult
|
var cohereResp CohereResponseResult
|
||||||
err = json.Unmarshal(responseBody, &cohereResp)
|
err = json.Unmarshal(responseBody, &cohereResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -218,10 +214,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var cohereResp CohereRerankResponseResult
|
var cohereResp CohereRerankResponseResult
|
||||||
err = json.Unmarshal(responseBody, &cohereResp)
|
err = json.Unmarshal(responseBody, &cohereResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -48,10 +48,7 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
// convert coze response to openai response
|
// convert coze response to openai response
|
||||||
var response dto.TextResponse
|
var response dto.TextResponse
|
||||||
var cozeResponse CozeChatDetailResponse
|
var cozeResponse CozeChatDetailResponse
|
||||||
@@ -106,7 +103,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
|
|
||||||
var currentEvent string
|
var currentEvent string
|
||||||
var currentData string
|
var currentData string
|
||||||
var usage dto.Usage
|
var usage = &dto.Usage{}
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
@@ -114,7 +111,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
if line == "" {
|
if line == "" {
|
||||||
if currentEvent != "" && currentData != "" {
|
if currentEvent != "" && currentData != "" {
|
||||||
// handle last event
|
// handle last event
|
||||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||||
currentEvent = ""
|
currentEvent = ""
|
||||||
currentData = ""
|
currentData = ""
|
||||||
}
|
}
|
||||||
@@ -134,7 +131,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
|
|
||||||
// Last event
|
// Last event
|
||||||
if currentEvent != "" && currentData != "" {
|
if currentEvent != "" && currentData != "" {
|
||||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
@@ -143,12 +140,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
if usage.TotalTokens == 0 {
|
if usage.TotalTokens == 0 {
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, &usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
|||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||||
|
|
||||||
// Send request
|
// Send request
|
||||||
client := service.GetImpatientHttpClient()
|
client := service.GetHttpClient()
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to send request: " + err.Error())
|
common.SysError("failed to send request: " + err.Error())
|
||||||
@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
common.SysError("close_response_body_failed: " + err.Error())
|
|
||||||
}
|
|
||||||
if usage.TotalTokens == 0 {
|
if usage.TotalTokens == 0 {
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
||||||
}
|
}
|
||||||
usage.CompletionTokens += nodeToken
|
usage.CompletionTokens += nodeToken
|
||||||
return nil, usage
|
return nil, usage
|
||||||
@@ -264,10 +257,7 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &difyResponse)
|
err = json.Unmarshal(responseBody, &difyResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
|||||||
@@ -72,10 +72,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
|
|
||||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||||
// suffix -thinking and -nothinking
|
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||||
|
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||||
|
info.UpstreamModelName = parts[0]
|
||||||
|
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
|
||||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -140,6 +140,7 @@ type GeminiChatGenerationConfig struct {
|
|||||||
Seed int64 `json:"seed,omitempty"`
|
Seed int64 `json:"seed,omitempty"`
|
||||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||||
|
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiChatCandidate struct {
|
type GeminiChatCandidate struct {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package gemini
|
package gemini
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -9,20 +8,19 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
// 读取响应体
|
// 读取响应体
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println(string(responseBody))
|
println(string(responseBody))
|
||||||
@@ -30,28 +28,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
|||||||
|
|
||||||
// 解析为 Gemini 原生响应格式
|
// 解析为 Gemini 原生响应格式
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err = common.DecodeJson(responseBody, &geminiResponse)
|
err = common.UnmarshalJson(responseBody, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否有候选响应
|
|
||||||
if len(geminiResponse.Candidates) == 0 {
|
|
||||||
return nil, &dto.OpenAIErrorWithStatusCode{
|
|
||||||
Error: dto.OpenAIError{
|
|
||||||
Message: "No candidates returned",
|
|
||||||
Type: "server_error",
|
|
||||||
Param: "",
|
|
||||||
Code: 500,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 计算使用量(基于 UsageMetadata)
|
// 计算使用量(基于 UsageMetadata)
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,18 +51,12 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||||
jsonResponse, err := json.Marshal(geminiResponse)
|
jsonResponse, err := common.EncodeJson(geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置响应头并写入响应
|
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
if err != nil {
|
|
||||||
return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &usage, nil
|
return &usage, nil
|
||||||
}
|
}
|
||||||
@@ -88,9 +67,11 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
|
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
|
responseText := strings.Builder{}
|
||||||
|
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||||
return false
|
return false
|
||||||
@@ -102,13 +83,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||||
imageCount++
|
imageCount++
|
||||||
}
|
}
|
||||||
|
if part.Text != "" {
|
||||||
|
responseText.WriteString(part.Text)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新使用量统计
|
// 更新使用量统计
|
||||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||||
@@ -121,7 +105,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 直接发送 GeminiChatResponse 响应
|
// 直接发送 GeminiChatResponse 响应
|
||||||
err = helper.ObjectData(c, geminiResponse)
|
err = helper.StringData(c, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
}
|
}
|
||||||
@@ -135,8 +119,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算最终使用量
|
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
if usage.CompletionTokens == 0 {
|
||||||
|
str := responseText.String()
|
||||||
|
if len(str) > 0 {
|
||||||
|
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
|
} else {
|
||||||
|
// 空补全,不需要使用量
|
||||||
|
usage = &dto.Usage{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||||
//helper.Done(c)
|
//helper.Done(c)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
@@ -36,6 +37,102 @@ var geminiSupportedMimeTypes = map[string]bool{
|
|||||||
"video/flv": true,
|
"video/flv": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Gemini 允许的思考预算范围
|
||||||
|
const (
|
||||||
|
pro25MinBudget = 128
|
||||||
|
pro25MaxBudget = 32768
|
||||||
|
flash25MaxBudget = 24576
|
||||||
|
flash25LiteMinBudget = 512
|
||||||
|
flash25LiteMaxBudget = 24576
|
||||||
|
)
|
||||||
|
|
||||||
|
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
|
||||||
|
func clampThinkingBudget(modelName string, budget int) int {
|
||||||
|
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||||
|
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||||
|
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||||
|
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||||
|
|
||||||
|
if is25FlashLite {
|
||||||
|
if budget < flash25LiteMinBudget {
|
||||||
|
return flash25LiteMinBudget
|
||||||
|
}
|
||||||
|
if budget > flash25LiteMaxBudget {
|
||||||
|
return flash25LiteMaxBudget
|
||||||
|
}
|
||||||
|
} else if isNew25Pro {
|
||||||
|
if budget < pro25MinBudget {
|
||||||
|
return pro25MinBudget
|
||||||
|
}
|
||||||
|
if budget > pro25MaxBudget {
|
||||||
|
return pro25MaxBudget
|
||||||
|
}
|
||||||
|
} else { // 其他模型
|
||||||
|
if budget < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if budget > flash25MaxBudget {
|
||||||
|
return flash25MaxBudget
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return budget
|
||||||
|
}
|
||||||
|
|
||||||
|
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||||
|
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||||
|
modelName := info.UpstreamModelName
|
||||||
|
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||||
|
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||||
|
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||||
|
|
||||||
|
if strings.Contains(modelName, "-thinking-") {
|
||||||
|
parts := strings.SplitN(modelName, "-thinking-", 2)
|
||||||
|
if len(parts) == 2 && parts[1] != "" {
|
||||||
|
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||||
|
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||||
|
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||||
|
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||||
|
IncludeThoughts: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if strings.HasSuffix(modelName, "-thinking") {
|
||||||
|
unsupportedModels := []string{
|
||||||
|
"gemini-2.5-pro-preview-05-06",
|
||||||
|
"gemini-2.5-pro-preview-03-25",
|
||||||
|
}
|
||||||
|
isUnsupported := false
|
||||||
|
for _, unsupportedModel := range unsupportedModels {
|
||||||
|
if strings.HasPrefix(modelName, unsupportedModel) {
|
||||||
|
isUnsupported = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isUnsupported {
|
||||||
|
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||||
|
IncludeThoughts: true,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||||
|
IncludeThoughts: true,
|
||||||
|
}
|
||||||
|
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||||
|
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||||
|
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
|
||||||
|
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||||
|
if !isNew25Pro {
|
||||||
|
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||||
|
ThinkingBudget: common.GetPointer(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||||
|
|
||||||
@@ -56,67 +153,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
ThinkingAdaptor(&geminiRequest, info)
|
||||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
|
||||||
// 硬编码不支持 ThinkingBudget 的旧模型
|
|
||||||
unsupportedModels := []string{
|
|
||||||
"gemini-2.5-pro-preview-05-06",
|
|
||||||
"gemini-2.5-pro-preview-03-25",
|
|
||||||
}
|
|
||||||
|
|
||||||
isUnsupported := false
|
|
||||||
for _, unsupportedModel := range unsupportedModels {
|
|
||||||
if strings.HasPrefix(info.OriginModelName, unsupportedModel) {
|
|
||||||
isUnsupported = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isUnsupported {
|
|
||||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
|
||||||
IncludeThoughts: true,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
|
||||||
|
|
||||||
// 检查是否为新的2.5pro模型(支持ThinkingBudget但有特殊范围)
|
|
||||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
|
||||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
|
||||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
|
||||||
|
|
||||||
if isNew25Pro {
|
|
||||||
// 新的2.5pro模型:ThinkingBudget范围为128-32768
|
|
||||||
if budgetTokens == 0 || budgetTokens < 128 {
|
|
||||||
budgetTokens = 128
|
|
||||||
} else if budgetTokens > 32768 {
|
|
||||||
budgetTokens = 32768
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 其他模型:ThinkingBudget范围为0-24576
|
|
||||||
if budgetTokens == 0 || budgetTokens > 24576 {
|
|
||||||
budgetTokens = 24576
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
|
||||||
ThinkingBudget: common.GetPointer(int(budgetTokens)),
|
|
||||||
IncludeThoughts: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
|
||||||
// 检查是否为新的2.5pro模型(不支持-nothinking,因为最低值只能为128)
|
|
||||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
|
||||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
|
||||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
|
||||||
|
|
||||||
if !isNew25Pro {
|
|
||||||
// 只有非新2.5pro模型才支持-nothinking
|
|
||||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
|
||||||
ThinkingBudget: common.GetPointer(0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||||
for _, category := range SafetySettingList {
|
for _, category := range SafetySettingList {
|
||||||
@@ -283,7 +320,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
|
|
||||||
// 校验 MimeType 是否在 Gemini 支持的白名单中
|
// 校验 MimeType 是否在 Gemini 支持的白名单中
|
||||||
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
|
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
|
||||||
return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
|
url := part.GetImageMedia().Url
|
||||||
|
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||||
}
|
}
|
||||||
|
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, GeminiPart{
|
||||||
@@ -341,8 +379,10 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
if content.Role == "assistant" {
|
if content.Role == "assistant" {
|
||||||
content.Role = "model"
|
content.Role = "model"
|
||||||
}
|
}
|
||||||
|
if len(content.Parts) > 0 {
|
||||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(system_content) > 0 {
|
if len(system_content) > 0 {
|
||||||
geminiRequest.SystemInstructions = &GeminiChatContent{
|
geminiRequest.SystemInstructions = &GeminiChatContent{
|
||||||
@@ -611,9 +651,9 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
fullTextResponse := dto.OpenAITextResponse{
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
Id: helper.GetResponseID(c),
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||||
@@ -754,14 +794,14 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
|||||||
|
|
||||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
// responseText := ""
|
// responseText := ""
|
||||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
id := helper.GetResponseID(c)
|
||||||
createAt := common.GetTimestamp()
|
createAt := common.GetTimestamp()
|
||||||
var usage = &dto.Usage{}
|
var usage = &dto.Usage{}
|
||||||
var imageCount int
|
var imageCount int
|
||||||
|
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||||
return false
|
return false
|
||||||
@@ -826,15 +866,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println(string(responseBody))
|
println(string(responseBody))
|
||||||
}
|
}
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err = common.DecodeJson(responseBody, &geminiResponse)
|
err = common.UnmarshalJson(responseBody, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
@@ -849,7 +886,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||||
fullTextResponse.Model = info.UpstreamModelName
|
fullTextResponse.Model = info.UpstreamModelName
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||||
@@ -880,11 +917,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
responseBody, readErr := io.ReadAll(resp.Body)
|
responseBody, readErr := io.ReadAll(resp.Body)
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
_ = resp.Body.Close()
|
|
||||||
|
|
||||||
var geminiResponse GeminiEmbeddingResponse
|
var geminiResponse GeminiEmbeddingResponse
|
||||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||||
@@ -916,14 +954,11 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
|
|||||||
}
|
}
|
||||||
openAIResponse.Usage = *usage.(*dto.Usage)
|
openAIResponse.Usage = *usage.(*dto.Usage)
|
||||||
|
|
||||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
jsonResponse, jsonErr := common.EncodeJson(openAIResponse)
|
||||||
if jsonErr != nil {
|
if jsonErr != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, _ = c.Writer.Write(jsonResponse)
|
|
||||||
|
|
||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package jina
|
|||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"jina-clip-v1",
|
"jina-clip-v1",
|
||||||
"jina-reranker-v2-base-multilingual",
|
"jina-reranker-v2-base-multilingual",
|
||||||
|
"jina-reranker-m0",
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "jina"
|
var ChannelName = "jina"
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
)
|
)
|
||||||
@@ -53,10 +54,7 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -80,4 +78,3 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|||||||
_, err = c.Writer.Write(jsonResponse)
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
return nil, &fullTextResponse.Usage
|
return nil, &fullTextResponse.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
package ollama
|
package ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -88,10 +88,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -120,31 +117,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
|
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
||||||
// So the httpClient will be confused by the response.
|
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
// Copy headers
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
// 删除任何现有的相同头部,以防止重复添加头部
|
|
||||||
c.Writer.Header().Del(k)
|
|
||||||
for _, vv := range v {
|
|
||||||
c.Writer.Header().Add(k, vv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// reset content length
|
|
||||||
c.Writer.Header().Del("Content-Length")
|
|
||||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,7 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"one-api/common"
|
"one-api/constant"
|
||||||
constant2 "one-api/constant"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/ai360"
|
"one-api/relay/channel/ai360"
|
||||||
@@ -21,7 +20,7 @@ import (
|
|||||||
"one-api/relay/channel/xinference"
|
"one-api/relay/channel/xinference"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/common_handler"
|
"one-api/relay/common_handler"
|
||||||
"one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -54,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
a.ChannelType = info.ChannelType
|
a.ChannelType = info.ChannelType
|
||||||
|
|
||||||
// initialize ThinkingContentInfo when thinking_to_content is enabled
|
// initialize ThinkingContentInfo when thinking_to_content is enabled
|
||||||
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
||||||
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
|
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
|
||||||
IsFirstThinkingContent: true,
|
IsFirstThinkingContent: true,
|
||||||
SendLastThinkingContent: false,
|
SendLastThinkingContent: false,
|
||||||
@@ -67,7 +66,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||||
}
|
}
|
||||||
if info.RelayMode == constant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||||
baseUrl = "wss://" + baseUrl
|
baseUrl = "wss://" + baseUrl
|
||||||
@@ -79,29 +78,36 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch info.ChannelType {
|
switch info.ChannelType {
|
||||||
case common.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
apiVersion := info.ApiVersion
|
apiVersion := info.ApiVersion
|
||||||
if apiVersion == "" {
|
if apiVersion == "" {
|
||||||
apiVersion = constant2.AzureDefaultAPIVersion
|
apiVersion = constant.AzureDefaultAPIVersion
|
||||||
}
|
}
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||||
|
|
||||||
|
// 特殊处理 responses API
|
||||||
|
if info.RelayMode == relayconstant.RelayModeResponses {
|
||||||
|
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
|
||||||
|
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||||
|
}
|
||||||
|
|
||||||
model_ := info.UpstreamModelName
|
model_ := info.UpstreamModelName
|
||||||
// 2025年5月10日后创建的渠道不移除.
|
// 2025年5月10日后创建的渠道不移除.
|
||||||
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
|
if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
|
||||||
model_ = strings.Replace(model_, ".", "", -1)
|
model_ = strings.Replace(model_, ".", "", -1)
|
||||||
}
|
}
|
||||||
// https://github.com/songquanpeng/one-api/issues/67
|
// https://github.com/songquanpeng/one-api/issues/67
|
||||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||||
if info.RelayMode == constant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
||||||
}
|
}
|
||||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||||
case common.ChannelTypeMiniMax:
|
case constant.ChannelTypeMiniMax:
|
||||||
return minimax.GetRequestURL(info)
|
return minimax.GetRequestURL(info)
|
||||||
case common.ChannelTypeCustom:
|
case constant.ChannelTypeCustom:
|
||||||
url := info.BaseUrl
|
url := info.BaseUrl
|
||||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||||
return url, nil
|
return url, nil
|
||||||
@@ -112,14 +118,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, header)
|
channel.SetupApiRequestHeader(info, c, header)
|
||||||
if info.ChannelType == common.ChannelTypeAzure {
|
if info.ChannelType == constant.ChannelTypeAzure {
|
||||||
header.Set("api-key", info.ApiKey)
|
header.Set("api-key", info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
|
||||||
header.Set("OpenAI-Organization", info.Organization)
|
header.Set("OpenAI-Organization", info.Organization)
|
||||||
}
|
}
|
||||||
if info.RelayMode == constant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||||
if swp != "" {
|
if swp != "" {
|
||||||
items := []string{
|
items := []string{
|
||||||
@@ -138,7 +144,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
|||||||
} else {
|
} else {
|
||||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||||
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
||||||
header.Set("X-Title", "New API")
|
header.Set("X-Title", "New API")
|
||||||
}
|
}
|
||||||
@@ -149,9 +155,14 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
|
if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
|
||||||
request.StreamOptions = nil
|
request.StreamOptions = nil
|
||||||
}
|
}
|
||||||
|
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||||
|
if len(request.Usage) == 0 {
|
||||||
|
request.Usage = json.RawMessage(`{"include":true}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
if strings.HasPrefix(request.Model, "o") {
|
if strings.HasPrefix(request.Model, "o") {
|
||||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||||
request.MaxCompletionTokens = request.MaxTokens
|
request.MaxCompletionTokens = request.MaxTokens
|
||||||
@@ -193,7 +204,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
a.ResponseFormat = request.ResponseFormat
|
a.ResponseFormat = request.ResponseFormat
|
||||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
if info.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||||
jsonData, err := json.Marshal(request)
|
jsonData, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error marshalling object: %w", err)
|
return nil, fmt.Errorf("error marshalling object: %w", err)
|
||||||
@@ -242,7 +253,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesEdits:
|
||||||
|
|
||||||
var requestBody bytes.Buffer
|
var requestBody bytes.Buffer
|
||||||
writer := multipart.NewWriter(&requestBody)
|
writer := multipart.NewWriter(&requestBody)
|
||||||
@@ -399,11 +410,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
if info.RelayMode == constant.RelayModeAudioTranscription ||
|
if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
|
||||||
info.RelayMode == constant.RelayModeAudioTranslation ||
|
info.RelayMode == relayconstant.RelayModeAudioTranslation ||
|
||||||
info.RelayMode == constant.RelayModeImagesEdits {
|
info.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||||
return channel.DoFormRequest(a, c, info, requestBody)
|
return channel.DoFormRequest(a, c, info, requestBody)
|
||||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
} else if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
return channel.DoWssRequest(a, c, info, requestBody)
|
return channel.DoWssRequest(a, c, info, requestBody)
|
||||||
} else {
|
} else {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
@@ -412,19 +423,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeRealtime:
|
case relayconstant.RelayModeRealtime:
|
||||||
err, usage = OpenaiRealtimeHandler(c, info)
|
err, usage = OpenaiRealtimeHandler(c, info)
|
||||||
case constant.RelayModeAudioSpeech:
|
case relayconstant.RelayModeAudioSpeech:
|
||||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||||
case constant.RelayModeAudioTranslation:
|
case relayconstant.RelayModeAudioTranslation:
|
||||||
fallthrough
|
fallthrough
|
||||||
case constant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||||
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
||||||
case constant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||||
case constant.RelayModeResponses:
|
case relayconstant.RelayModeResponses:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = OaiResponsesStreamHandler(c, resp, info)
|
err, usage = OaiResponsesStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
@@ -442,17 +453,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
|
|
||||||
func (a *Adaptor) GetModelList() []string {
|
func (a *Adaptor) GetModelList() []string {
|
||||||
switch a.ChannelType {
|
switch a.ChannelType {
|
||||||
case common.ChannelType360:
|
case constant.ChannelType360:
|
||||||
return ai360.ModelList
|
return ai360.ModelList
|
||||||
case common.ChannelTypeMoonshot:
|
case constant.ChannelTypeMoonshot:
|
||||||
return moonshot.ModelList
|
return moonshot.ModelList
|
||||||
case common.ChannelTypeLingYiWanWu:
|
case constant.ChannelTypeLingYiWanWu:
|
||||||
return lingyiwanwu.ModelList
|
return lingyiwanwu.ModelList
|
||||||
case common.ChannelTypeMiniMax:
|
case constant.ChannelTypeMiniMax:
|
||||||
return minimax.ModelList
|
return minimax.ModelList
|
||||||
case common.ChannelTypeXinference:
|
case constant.ChannelTypeXinference:
|
||||||
return xinference.ModelList
|
return xinference.ModelList
|
||||||
case common.ChannelTypeOpenRouter:
|
case constant.ChannelTypeOpenRouter:
|
||||||
return openrouter.ModelList
|
return openrouter.ModelList
|
||||||
default:
|
default:
|
||||||
return ModelList
|
return ModelList
|
||||||
@@ -461,17 +472,17 @@ func (a *Adaptor) GetModelList() []string {
|
|||||||
|
|
||||||
func (a *Adaptor) GetChannelName() string {
|
func (a *Adaptor) GetChannelName() string {
|
||||||
switch a.ChannelType {
|
switch a.ChannelType {
|
||||||
case common.ChannelType360:
|
case constant.ChannelType360:
|
||||||
return ai360.ChannelName
|
return ai360.ChannelName
|
||||||
case common.ChannelTypeMoonshot:
|
case constant.ChannelTypeMoonshot:
|
||||||
return moonshot.ChannelName
|
return moonshot.ChannelName
|
||||||
case common.ChannelTypeLingYiWanWu:
|
case constant.ChannelTypeLingYiWanWu:
|
||||||
return lingyiwanwu.ChannelName
|
return lingyiwanwu.ChannelName
|
||||||
case common.ChannelTypeMiniMax:
|
case constant.ChannelTypeMiniMax:
|
||||||
return minimax.ChannelName
|
return minimax.ChannelName
|
||||||
case common.ChannelTypeXinference:
|
case constant.ChannelTypeXinference:
|
||||||
return xinference.ChannelName
|
return xinference.ChannelName
|
||||||
case common.ChannelTypeOpenRouter:
|
case constant.ChannelTypeOpenRouter:
|
||||||
return openrouter.ChannelName
|
return openrouter.ChannelName
|
||||||
default:
|
default:
|
||||||
return ChannelName
|
return ChannelName
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
@@ -15,6 +14,7 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
@@ -33,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
|||||||
}
|
}
|
||||||
|
|
||||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||||
if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
|
if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,12 +110,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
containStreamUsage := false
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
|
model := info.UpstreamModelName
|
||||||
var responseId string
|
var responseId string
|
||||||
var createAt int64 = 0
|
var createAt int64 = 0
|
||||||
var systemFingerprint string
|
var systemFingerprint string
|
||||||
model := info.UpstreamModelName
|
var containStreamUsage bool
|
||||||
|
|
||||||
var responseTextBuilder strings.Builder
|
var responseTextBuilder strings.Builder
|
||||||
var toolCount int
|
var toolCount int
|
||||||
var usage = &dto.Usage{}
|
var usage = &dto.Usage{}
|
||||||
@@ -147,31 +148,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 处理最后的响应
|
||||||
shouldSendLastResp := true
|
shouldSendLastResp := true
|
||||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
|
||||||
err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
|
&containStreamUsage, info, &shouldSendLastResp); err != nil {
|
||||||
if err == nil {
|
common.SysError("error handling last response: " + err.Error())
|
||||||
responseId = lastStreamResponse.Id
|
|
||||||
createAt = lastStreamResponse.Created
|
|
||||||
systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
|
||||||
model = lastStreamResponse.Model
|
|
||||||
if service.ValidUsage(lastStreamResponse.Usage) {
|
|
||||||
containStreamUsage = true
|
|
||||||
usage = lastStreamResponse.Usage
|
|
||||||
if !info.ShouldIncludeUsage {
|
|
||||||
shouldSendLastResp = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, choice := range lastStreamResponse.Choices {
|
|
||||||
if choice.FinishReason != nil {
|
|
||||||
shouldSendLastResp = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldSendLastResp {
|
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||||
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||||
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理token计算
|
// 处理token计算
|
||||||
@@ -180,10 +165,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !containStreamUsage {
|
if !containStreamUsage {
|
||||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
} else {
|
} else {
|
||||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
if info.ChannelType == constant.ChannelTypeDeepSeek {
|
||||||
if usage.PromptCacheHitTokens != 0 {
|
if usage.PromptCacheHitTokens != 0 {
|
||||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||||
}
|
}
|
||||||
@@ -196,16 +181,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
var simpleResponse dto.OpenAITextResponse
|
var simpleResponse dto.OpenAITextResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = common.UnmarshalJson(responseBody, &simpleResponse)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = common.DecodeJson(responseBody, &simpleResponse)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
@@ -224,7 +207,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||||
completionTokens := 0
|
completionTokens := 0
|
||||||
for _, choice := range simpleResponse.Choices {
|
for _, choice := range simpleResponse.Choices {
|
||||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||||
completionTokens += ctkm
|
completionTokens += ctkm
|
||||||
}
|
}
|
||||||
simpleResponse.Usage = dto.Usage{
|
simpleResponse.Usage = dto.Usage{
|
||||||
@@ -237,7 +220,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatOpenAI:
|
case relaycommon.RelayFormatOpenAI:
|
||||||
if forceFormat {
|
if forceFormat {
|
||||||
responseBody, err = json.Marshal(simpleResponse)
|
responseBody, err = common.EncodeJson(simpleResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
@@ -246,29 +229,15 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
}
|
}
|
||||||
case relaycommon.RelayFormatClaude:
|
case relaycommon.RelayFormatClaude:
|
||||||
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
||||||
claudeRespStr, err := json.Marshal(claudeResp)
|
claudeRespStr, err := common.EncodeJson(claudeResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
responseBody = claudeRespStr
|
responseBody = claudeRespStr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset response body
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
||||||
// So the httpClient will be confused by the response.
|
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
//return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
common.SysError("error copying response body: " + err.Error())
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
return nil, &simpleResponse.Usage
|
return nil, &simpleResponse.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||||
// and can be terminated directly.
|
// and can be terminated directly.
|
||||||
defer resp.Body.Close()
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage.PromptTokens = info.PromptTokens
|
||||||
usage.TotalTokens = info.PromptTokens
|
usage.TotalTokens = info.PromptTokens
|
||||||
@@ -296,6 +265,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
// count tokens by audio file duration
|
// count tokens by audio file duration
|
||||||
audioTokens, err := countAudioTokens(c)
|
audioTokens, err := countAudioTokens(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -305,25 +276,8 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
// 写入新的 response body
|
||||||
if err != nil {
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
// Reset response body
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
||||||
// So the httpClient will be confused by the response.
|
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = audioTokens
|
usage.PromptTokens = audioTokens
|
||||||
@@ -345,13 +299,14 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
|||||||
if err = c.ShouldBind(&reqBody); err != nil {
|
if err = c.ShouldBind(&reqBody); err != nil {
|
||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||||
reqFp, err := reqBody.File.Open()
|
reqFp, err := reqBody.File.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
defer reqFp.Close()
|
||||||
|
|
||||||
tmpFp, err := os.CreateTemp("", "audio-*")
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
@@ -365,7 +320,7 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
|||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
|
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
@@ -413,7 +368,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
|||||||
}
|
}
|
||||||
|
|
||||||
realtimeEvent := &dto.RealtimeEvent{}
|
realtimeEvent := &dto.RealtimeEvent{}
|
||||||
err = json.Unmarshal(message, realtimeEvent)
|
err = common.UnmarshalJson(message, realtimeEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||||
return
|
return
|
||||||
@@ -473,7 +428,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
|||||||
}
|
}
|
||||||
info.SetFirstResponseTime()
|
info.SetFirstResponseTime()
|
||||||
realtimeEvent := &dto.RealtimeEvent{}
|
realtimeEvent := &dto.RealtimeEvent{}
|
||||||
err = json.Unmarshal(message, realtimeEvent)
|
err = common.UnmarshalJson(message, realtimeEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||||
return
|
return
|
||||||
@@ -520,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
|||||||
localUsage = &dto.RealtimeUsage{}
|
localUsage = &dto.RealtimeUsage{}
|
||||||
// print now usage
|
// print now usage
|
||||||
}
|
}
|
||||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||||
|
|
||||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||||
realtimeSession := realtimeEvent.Session
|
realtimeSession := realtimeEvent.Session
|
||||||
@@ -599,40 +554,25 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
// Reset response body
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
||||||
// So the httpClient will be confused by the response.
|
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
// reset content length
|
|
||||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var usageResp dto.SimpleResponse
|
var usageResp dto.SimpleResponse
|
||||||
err = json.Unmarshal(responseBody, &usageResp)
|
err = common.UnmarshalJson(responseBody, &usageResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 写入新的 response body
|
||||||
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
|
|
||||||
|
// Once we've written to the client, we should not return errors anymore
|
||||||
|
// because the upstream has already consumed resources and returned content
|
||||||
|
// We should still perform billing even if parsing fails
|
||||||
// format
|
// format
|
||||||
if usageResp.InputTokens > 0 {
|
if usageResp.InputTokens > 0 {
|
||||||
usageResp.PromptTokens += usageResp.InputTokens
|
usageResp.PromptTokens += usageResp.InputTokens
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -16,17 +15,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
// read response body
|
// read response body
|
||||||
var responsesResponse dto.OpenAIResponsesResponse
|
var responsesResponse dto.OpenAIResponsesResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = common.UnmarshalJson(responseBody, &responsesResponse)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = common.DecodeJson(responseBody, &responsesResponse)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
@@ -41,22 +38,9 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// reset response body
|
// 写入新的 response body
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
||||||
// So the httpClient will be confused by the response.
|
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
// copy response body
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error copying response body: " + err.Error())
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
// compute usage
|
// compute usage
|
||||||
usage := dto.Usage{}
|
usage := dto.Usage{}
|
||||||
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
||||||
@@ -82,7 +66,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
|||||||
|
|
||||||
// 检查当前数据是否包含 completed 状态和 usage 信息
|
// 检查当前数据是否包含 completed 状态和 usage 信息
|
||||||
var streamResponse dto.ResponsesStreamResponse
|
var streamResponse dto.ResponsesStreamResponse
|
||||||
if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
|
if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
|
||||||
sendResponsesStreamData(c, streamResponse, data)
|
sendResponsesStreamData(c, streamResponse, data)
|
||||||
switch streamResponse.Type {
|
switch streamResponse.Type {
|
||||||
case "response.completed":
|
case "response.completed":
|
||||||
@@ -110,7 +94,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
|||||||
tempStr := responseTextBuilder.String()
|
tempStr := responseTextBuilder.String()
|
||||||
if len(tempStr) > 0 {
|
if len(tempStr) > 0 {
|
||||||
// 非正常结束,使用输出文本的 token 数量
|
// 非正常结束,使用输出文本的 token 数量
|
||||||
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
|
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
|
||||||
usage.CompletionTokens = completionTokens
|
usage.CompletionTokens = completionTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = palmStreamHandler(c, resp)
|
err, responseText = palmStreamHandler(c, resp)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package palm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -73,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
|
|||||||
|
|
||||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||||
responseText := ""
|
responseText := ""
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
responseId := helper.GetResponseID(c)
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
dataChan := make(chan string)
|
dataChan := make(chan string)
|
||||||
stopChan := make(chan bool)
|
stopChan := make(chan bool)
|
||||||
@@ -84,12 +83,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
|||||||
stopChan <- true
|
stopChan <- true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
common.SysError("error closing stream response: " + err.Error())
|
|
||||||
stopChan <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var palmResponse PaLMChatResponse
|
var palmResponse PaLMChatResponse
|
||||||
err = json.Unmarshal(responseBody, &palmResponse)
|
err = json.Unmarshal(responseBody, &palmResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -123,10 +117,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
err := resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
return nil, responseText
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,10 +126,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var palmResponse PaLMChatResponse
|
var palmResponse PaLMChatResponse
|
||||||
err = json.Unmarshal(responseBody, &palmResponse)
|
err = json.Unmarshal(responseBody, &palmResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -156,7 +144,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||||
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
)
|
)
|
||||||
@@ -14,10 +15,7 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var siliconflowResp SFRerankResponse
|
var siliconflowResp SFRerankResponse
|
||||||
err = json.Unmarshal(responseBody, &siliconflowResp)
|
err = json.Unmarshal(responseBody, &siliconflowResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
380
relay/channel/task/jimeng/adaptor.go
Normal file
380
relay/channel/task/jimeng/adaptor.go
Normal file
@@ -0,0 +1,380 @@
|
|||||||
|
package jimeng
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"one-api/model"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Request / Response structures
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
type requestPayload struct {
|
||||||
|
ReqKey string `json:"req_key"`
|
||||||
|
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
||||||
|
ImageUrls []string `json:"image_urls,omitempty"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Seed int64 `json:"seed"`
|
||||||
|
AspectRatio string `json:"aspect_ratio"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type responsePayload struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
RequestId string `json:"request_id"`
|
||||||
|
Data struct {
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseTask struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data struct {
|
||||||
|
BinaryDataBase64 []interface{} `json:"binary_data_base64"`
|
||||||
|
ImageUrls interface{} `json:"image_urls"`
|
||||||
|
RespData string `json:"resp_data"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
VideoUrl string `json:"video_url"`
|
||||||
|
} `json:"data"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
RequestId string `json:"request_id"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
TimeElapsed string `json:"time_elapsed"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Adaptor implementation
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
type TaskAdaptor struct {
|
||||||
|
ChannelType int
|
||||||
|
accessKey string
|
||||||
|
secretKey string
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
|
a.ChannelType = info.ChannelType
|
||||||
|
a.baseURL = info.BaseUrl
|
||||||
|
|
||||||
|
// apiKey format: "access_key|secret_key"
|
||||||
|
keyParts := strings.Split(info.ApiKey, "|")
|
||||||
|
if len(keyParts) == 2 {
|
||||||
|
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||||
|
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||||
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||||
|
// Accept only POST /v1/video/generations as "generate" action.
|
||||||
|
action := constant.TaskActionGenerate
|
||||||
|
info.Action = action
|
||||||
|
|
||||||
|
req := relaycommon.TaskSubmitReq{}
|
||||||
|
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Prompt) == "" {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store into context for later usage
|
||||||
|
c.Set("task_request", req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestURL constructs the upstream URL.
|
||||||
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||||
|
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestHeader sets required headers.
|
||||||
|
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestBody converts request into Jimeng specific format.
|
||||||
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||||
|
v, exists := c.Get("task_request")
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("request not found in context")
|
||||||
|
}
|
||||||
|
req := v.(relaycommon.TaskSubmitReq)
|
||||||
|
|
||||||
|
body, err := a.convertToRequestPayload(&req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "convert request payload failed")
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return bytes.NewReader(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRequest delegates to common helper.
|
||||||
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoResponse handles upstream response, returns taskID etc.
|
||||||
|
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
// Parse Jimeng response
|
||||||
|
var jResp responsePayload
|
||||||
|
if err := json.Unmarshal(responseBody, &jResp); err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if jResp.Code != 10000 {
|
||||||
|
taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
|
||||||
|
return jResp.Data.TaskID, responseBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchTask fetch task status
|
||||||
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||||
|
taskID, ok := body["task_id"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid task_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
|
||||||
|
payload := map[string]string{
|
||||||
|
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
||||||
|
"task_id": taskID,
|
||||||
|
}
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "marshal fetch task payload failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
keyParts := strings.Split(key, "|")
|
||||||
|
if len(keyParts) != 2 {
|
||||||
|
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
|
||||||
|
}
|
||||||
|
accessKey := strings.TrimSpace(keyParts[0])
|
||||||
|
secretKey := strings.TrimSpace(keyParts[1])
|
||||||
|
|
||||||
|
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "sign request failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return service.GetHttpClient().Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetModelList() []string {
|
||||||
|
return []string{"jimeng_vgfm_t2v_l20"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetChannelName() string {
|
||||||
|
return "jimeng"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
|
||||||
|
var bodyBytes []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if req.Body != nil {
|
||||||
|
bodyBytes, err = io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "read request body failed")
|
||||||
|
}
|
||||||
|
_ = req.Body.Close()
|
||||||
|
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
|
||||||
|
} else {
|
||||||
|
bodyBytes = []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadHash := sha256.Sum256(bodyBytes)
|
||||||
|
hexPayloadHash := hex.EncodeToString(payloadHash[:])
|
||||||
|
|
||||||
|
t := time.Now().UTC()
|
||||||
|
xDate := t.Format("20060102T150405Z")
|
||||||
|
shortDate := t.Format("20060102")
|
||||||
|
|
||||||
|
req.Header.Set("Host", req.URL.Host)
|
||||||
|
req.Header.Set("X-Date", xDate)
|
||||||
|
req.Header.Set("X-Content-Sha256", hexPayloadHash)
|
||||||
|
|
||||||
|
// Sort and encode query parameters to create canonical query string
|
||||||
|
queryParams := req.URL.Query()
|
||||||
|
sortedKeys := make([]string, 0, len(queryParams))
|
||||||
|
for k := range queryParams {
|
||||||
|
sortedKeys = append(sortedKeys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(sortedKeys)
|
||||||
|
var queryParts []string
|
||||||
|
for _, k := range sortedKeys {
|
||||||
|
values := queryParams[k]
|
||||||
|
sort.Strings(values)
|
||||||
|
for _, v := range values {
|
||||||
|
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
canonicalQueryString := strings.Join(queryParts, "&")
|
||||||
|
|
||||||
|
headersToSign := map[string]string{
|
||||||
|
"host": req.URL.Host,
|
||||||
|
"x-date": xDate,
|
||||||
|
"x-content-sha256": hexPayloadHash,
|
||||||
|
}
|
||||||
|
if req.Header.Get("Content-Type") != "" {
|
||||||
|
headersToSign["content-type"] = req.Header.Get("Content-Type")
|
||||||
|
}
|
||||||
|
|
||||||
|
var signedHeaderKeys []string
|
||||||
|
for k := range headersToSign {
|
||||||
|
signedHeaderKeys = append(signedHeaderKeys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(signedHeaderKeys)
|
||||||
|
|
||||||
|
var canonicalHeaders strings.Builder
|
||||||
|
for _, k := range signedHeaderKeys {
|
||||||
|
canonicalHeaders.WriteString(k)
|
||||||
|
canonicalHeaders.WriteString(":")
|
||||||
|
canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
|
||||||
|
canonicalHeaders.WriteString("\n")
|
||||||
|
}
|
||||||
|
signedHeaders := strings.Join(signedHeaderKeys, ";")
|
||||||
|
|
||||||
|
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
||||||
|
req.Method,
|
||||||
|
req.URL.Path,
|
||||||
|
canonicalQueryString,
|
||||||
|
canonicalHeaders.String(),
|
||||||
|
signedHeaders,
|
||||||
|
hexPayloadHash,
|
||||||
|
)
|
||||||
|
|
||||||
|
hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
|
||||||
|
hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
|
||||||
|
|
||||||
|
region := "cn-north-1"
|
||||||
|
serviceName := "cv"
|
||||||
|
credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
|
||||||
|
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
|
||||||
|
xDate,
|
||||||
|
credentialScope,
|
||||||
|
hexHashedCanonicalRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
|
||||||
|
kRegion := hmacSHA256(kDate, []byte(region))
|
||||||
|
kService := hmacSHA256(kRegion, []byte(serviceName))
|
||||||
|
kSigning := hmacSHA256(kService, []byte("request"))
|
||||||
|
signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
|
||||||
|
|
||||||
|
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||||
|
accessKey,
|
||||||
|
credentialScope,
|
||||||
|
signedHeaders,
|
||||||
|
signature,
|
||||||
|
)
|
||||||
|
req.Header.Set("Authorization", authorization)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hmacSHA256(key []byte, data []byte) []byte {
|
||||||
|
h := hmac.New(sha256.New, key)
|
||||||
|
h.Write(data)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||||
|
r := requestPayload{
|
||||||
|
ReqKey: "jimeng_vgfm_i2v_l20",
|
||||||
|
Prompt: req.Prompt,
|
||||||
|
AspectRatio: "16:9", // Default aspect ratio
|
||||||
|
Seed: -1, // Default to random
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle one-of image_urls or binary_data_base64
|
||||||
|
if req.Image != "" {
|
||||||
|
if strings.HasPrefix(req.Image, "http") {
|
||||||
|
r.ImageUrls = []string{req.Image}
|
||||||
|
} else {
|
||||||
|
r.BinaryDataBase64 = []string{req.Image}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
metadata := req.Metadata
|
||||||
|
medaBytes, err := json.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(medaBytes, &r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||||
|
}
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||||
|
resTask := responseTask{}
|
||||||
|
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||||
|
}
|
||||||
|
taskResult := relaycommon.TaskInfo{}
|
||||||
|
if resTask.Code == 10000 {
|
||||||
|
taskResult.Code = 0
|
||||||
|
} else {
|
||||||
|
taskResult.Code = resTask.Code // todo uni code
|
||||||
|
taskResult.Reason = resTask.Message
|
||||||
|
taskResult.Status = model.TaskStatusFailure
|
||||||
|
taskResult.Progress = "100%"
|
||||||
|
}
|
||||||
|
switch resTask.Data.Status {
|
||||||
|
case "in_queue":
|
||||||
|
taskResult.Status = model.TaskStatusQueued
|
||||||
|
taskResult.Progress = "10%"
|
||||||
|
case "done":
|
||||||
|
taskResult.Status = model.TaskStatusSuccess
|
||||||
|
taskResult.Progress = "100%"
|
||||||
|
}
|
||||||
|
taskResult.Url = resTask.Data.VideoUrl
|
||||||
|
return &taskResult, nil
|
||||||
|
}
|
||||||
346
relay/channel/task/kling/adaptor.go
Normal file
346
relay/channel/task/kling/adaptor.go
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
package kling
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/model"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Request / Response structures
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
type SubmitReq struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
Mode string `json:"mode,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Duration int `json:"duration,omitempty"`
|
||||||
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestPayload struct {
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
Mode string `json:"mode,omitempty"`
|
||||||
|
Duration string `json:"duration,omitempty"`
|
||||||
|
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||||
|
ModelName string `json:"model_name,omitempty"`
|
||||||
|
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// responsePayload mirrors Kling's API envelope, used both for submission and
// status-query responses. Code 0 indicates upstream success.
type responsePayload struct {
	Code      int    `json:"code"`
	Message   string `json:"message"`
	RequestId string `json:"request_id"`
	Data      struct {
		TaskId        string `json:"task_id"`
		TaskStatus    string `json:"task_status"` // submitted / processing / succeed / failed
		TaskStatusMsg string `json:"task_status_msg"`
		TaskResult    struct {
			Videos []struct {
				Id       string `json:"id"`
				Url      string `json:"url"`
				Duration string `json:"duration"` // seconds, as a string
			} `json:"videos"`
		} `json:"task_result"`
		CreatedAt int64 `json:"created_at"`
		UpdatedAt int64 `json:"updated_at"`
	} `json:"data"`
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Adaptor implementation
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
// TaskAdaptor implements the task-channel adaptor interface for Kling's
// video generation API.
type TaskAdaptor struct {
	ChannelType int
	accessKey   string // first half of the "access_key|secret_key" credential
	secretKey   string // second half; used to sign JWTs
	baseURL     string
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
|
a.ChannelType = info.ChannelType
|
||||||
|
a.baseURL = info.BaseUrl
|
||||||
|
|
||||||
|
// apiKey format: "access_key|secret_key"
|
||||||
|
keyParts := strings.Split(info.ApiKey, "|")
|
||||||
|
if len(keyParts) == 2 {
|
||||||
|
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||||
|
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||||
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||||
|
// Accept only POST /v1/video/generations as "generate" action.
|
||||||
|
action := constant.TaskActionGenerate
|
||||||
|
info.Action = action
|
||||||
|
|
||||||
|
var req SubmitReq
|
||||||
|
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Prompt) == "" {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store into context for later usage
|
||||||
|
c.Set("task_request", req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestURL constructs the upstream URL.
|
||||||
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||||
|
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||||
|
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestHeader sets required headers.
|
||||||
|
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||||
|
token, err := a.createJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create JWT token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestBody converts request into Kling specific format.
|
||||||
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||||
|
v, exists := c.Get("task_request")
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("request not found in context")
|
||||||
|
}
|
||||||
|
req := v.(SubmitReq)
|
||||||
|
|
||||||
|
body, err := a.convertToRequestPayload(&req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return bytes.NewReader(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRequest delegates to common helper.
|
||||||
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
if action := c.GetString("action"); action != "" {
|
||||||
|
info.Action = action
|
||||||
|
}
|
||||||
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoResponse handles upstream response, returns taskID etc.
|
||||||
|
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt Kling response parse first.
|
||||||
|
var kResp responsePayload
|
||||||
|
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId})
|
||||||
|
return kResp.Data.TaskId, responseBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback generic task response.
|
||||||
|
var generic dto.TaskResponse[string]
|
||||||
|
if err := json.Unmarshal(responseBody, &generic); err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generic.IsSuccess() {
|
||||||
|
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
|
||||||
|
return generic.Data, responseBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchTask fetch task status
|
||||||
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||||
|
taskID, ok := body["task_id"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid task_id")
|
||||||
|
}
|
||||||
|
action, ok := body["action"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid action")
|
||||||
|
}
|
||||||
|
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||||
|
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := a.createJWTTokenWithKey(key)
|
||||||
|
if err != nil {
|
||||||
|
token = key
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||||
|
|
||||||
|
return service.GetHttpClient().Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetModelList() []string {
|
||||||
|
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetChannelName() string {
|
||||||
|
return "kling"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// helpers
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||||
|
r := requestPayload{
|
||||||
|
Prompt: req.Prompt,
|
||||||
|
Image: req.Image,
|
||||||
|
Mode: defaultString(req.Mode, "std"),
|
||||||
|
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||||
|
AspectRatio: a.getAspectRatio(req.Size),
|
||||||
|
ModelName: req.Model,
|
||||||
|
CfgScale: 0.5,
|
||||||
|
}
|
||||||
|
if r.ModelName == "" {
|
||||||
|
r.ModelName = "kling-v1"
|
||||||
|
}
|
||||||
|
metadata := req.Metadata
|
||||||
|
medaBytes, err := json.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(medaBytes, &r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||||
|
}
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) getAspectRatio(size string) string {
|
||||||
|
switch size {
|
||||||
|
case "1024x1024", "512x512":
|
||||||
|
return "1:1"
|
||||||
|
case "1280x720", "1920x1080":
|
||||||
|
return "16:9"
|
||||||
|
case "720x1280", "1080x1920":
|
||||||
|
return "9:16"
|
||||||
|
default:
|
||||||
|
return "1:1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultString returns def when s is empty or whitespace-only; otherwise it
// returns s unchanged (the returned value is NOT trimmed).
func defaultString(s, def string) string {
	if strings.TrimSpace(s) != "" {
		return s
	}
	return def
}
|
||||||
|
|
||||||
|
// defaultInt returns def when v is zero; any non-zero value (including
// negatives) is passed through unchanged.
func defaultInt(v int, def int) int {
	if v != 0 {
		return v
	}
	return def
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// JWT helpers
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
// createJWTToken signs a token using the access/secret keys captured in Init.
func (a *TaskAdaptor) createJWTToken() (string, error) {
	return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||||
|
parts := strings.Split(apiKey, "|")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
|
||||||
|
}
|
||||||
|
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
|
||||||
|
if accessKey == "" || secretKey == "" {
|
||||||
|
return "", fmt.Errorf("access key and secret key are required")
|
||||||
|
}
|
||||||
|
now := time.Now().Unix()
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": accessKey,
|
||||||
|
"exp": now + 1800, // 30 minutes
|
||||||
|
"nbf": now - 5,
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
token.Header["typ"] = "JWT"
|
||||||
|
return token.SignedString([]byte(secretKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||||
|
resPayload := responsePayload{}
|
||||||
|
err := json.Unmarshal(respBody, &resPayload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||||
|
}
|
||||||
|
taskInfo := &relaycommon.TaskInfo{}
|
||||||
|
taskInfo.Code = resPayload.Code
|
||||||
|
taskInfo.TaskID = resPayload.Data.TaskId
|
||||||
|
taskInfo.Reason = resPayload.Message
|
||||||
|
//任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败)
|
||||||
|
status := resPayload.Data.TaskStatus
|
||||||
|
switch status {
|
||||||
|
case "submitted":
|
||||||
|
taskInfo.Status = model.TaskStatusSubmitted
|
||||||
|
case "processing":
|
||||||
|
taskInfo.Status = model.TaskStatusInProgress
|
||||||
|
case "succeed":
|
||||||
|
taskInfo.Status = model.TaskStatusSuccess
|
||||||
|
case "failed":
|
||||||
|
taskInfo.Status = model.TaskStatusFailure
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown task status: %s", status)
|
||||||
|
}
|
||||||
|
if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
|
||||||
|
video := videos[0]
|
||||||
|
taskInfo.Url = video.Url
|
||||||
|
}
|
||||||
|
return taskInfo, nil
|
||||||
|
}
|
||||||
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
|
|||||||
ChannelType int
|
ChannelType int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
|
||||||
|
return nil, fmt.Errorf("not implement") // todo implement this method if needed
|
||||||
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
a.ChannelType = info.ChannelType
|
a.ChannelType = info.ChannelType
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user