Merge branch 'main' into feat-copy-models
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
name: Publish Docker image (amd64)
|
||||
name: Publish Docker image (alpha)
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
branches:
|
||||
- alpha
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
description: "reason"
|
||||
required: false
|
||||
|
||||
jobs:
|
||||
push_to_registries:
|
||||
name: Push Docker image to multiple registries
|
||||
@@ -22,7 +23,7 @@ jobs:
|
||||
|
||||
- name: Save version info
|
||||
run: |
|
||||
git describe --tags > VERSION
|
||||
echo "alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)" > VERSION
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
@@ -37,6 +38,9 @@ jobs:
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
@@ -44,11 +48,15 @@ jobs:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
tags: |
|
||||
type=raw,value=alpha
|
||||
type=raw,value=alpha-{{date 'YYYYMMDD'}}-{{sha}}
|
||||
|
||||
- name: Build and push Docker images
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
7
.github/workflows/docker-image-arm64.yml
vendored
7
.github/workflows/docker-image-arm64.yml
vendored
@@ -1,14 +1,9 @@
|
||||
name: Publish Docker image (arm64)
|
||||
name: Publish Docker image (Multi Registries)
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
jobs:
|
||||
push_to_registries:
|
||||
name: Push Docker image to multiple registries
|
||||
|
||||
13
.github/workflows/linux-release.yml
vendored
13
.github/workflows/linux-release.yml
vendored
@@ -3,6 +3,11 @@ permissions:
|
||||
contents: write
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
@@ -15,16 +20,16 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-node@v3
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
node-version: 18
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
npm install
|
||||
REACT_APP_VERSION=$(git describe --tags) npm run build
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
|
||||
13
.github/workflows/macos-release.yml
vendored
13
.github/workflows/macos-release.yml
vendored
@@ -3,6 +3,11 @@ permissions:
|
||||
contents: write
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
@@ -15,16 +20,16 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-node@v3
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
node-version: 18
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
npm install
|
||||
REACT_APP_VERSION=$(git describe --tags) npm run build
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
|
||||
13
.github/workflows/windows-release.yml
vendored
13
.github/workflows/windows-release.yml
vendored
@@ -3,6 +3,11 @@ permissions:
|
||||
contents: write
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
@@ -18,16 +23,16 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-node@v3
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
node-version: 18
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
npm install
|
||||
REACT_APP_VERSION=$(git describe --tags) npm run build
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
|
||||
@@ -24,8 +24,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk update \
|
||||
&& apk upgrade \
|
||||
RUN apk upgrade --no-cache \
|
||||
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
||||
&& update-ca-certificates
|
||||
|
||||
|
||||
@@ -44,6 +44,9 @@
|
||||
|
||||
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||
|
||||
You can also access the AI-generated DeepWiki:
|
||||
[](https://deepwiki.com/QuantumNous/new-api)
|
||||
|
||||
## ✨ Key Features
|
||||
|
||||
New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:
|
||||
@@ -110,6 +113,7 @@ For detailed configuration instructions, please refer to [Installation Guide-Env
|
||||
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
|
||||
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
|
||||
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
|
||||
- `ERROR_LOG_ENABLED=true`: Whether to record and display error logs, default is `false`
|
||||
|
||||
## Deployment
|
||||
|
||||
|
||||
@@ -27,6 +27,9 @@
|
||||
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
|
||||
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
|
||||
</a>
|
||||
<a href="https://coderabbit.ai">
|
||||
<img src="https://img.shields.io/coderabbit/prs/github/QuantumNous/new-api?utm_source=oss&utm_medium=github&utm_campaign=QuantumNous%2Fnew-api&labelColor=171717&color=FF570A&link=https%3A%2F%2Fcoderabbit.ai&label=CodeRabbit+Reviews" alt="CodeRabbit Pull Request Reviews">
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@@ -44,6 +47,9 @@
|
||||
|
||||
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||
|
||||
也可访问AI生成的DeepWiki:
|
||||
[](https://deepwiki.com/QuantumNous/new-api)
|
||||
|
||||
## ✨ 主要特性
|
||||
|
||||
New API提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction):
|
||||
@@ -110,6 +116,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
||||
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview`
|
||||
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
|
||||
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
|
||||
- `ERROR_LOG_ENABLED=true`: 是否记录并显示错误日志,默认`false`
|
||||
|
||||
## 部署
|
||||
|
||||
@@ -176,7 +183,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||
|
||||
其他基于New API的项目:
|
||||
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
|
||||
- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
|
||||
|
||||
## 帮助支持
|
||||
|
||||
|
||||
@@ -241,6 +241,7 @@ const (
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
}
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
package common
|
||||
|
||||
const (
|
||||
DatabaseTypeMySQL = "mysql"
|
||||
DatabaseTypeSQLite = "sqlite"
|
||||
DatabaseTypePostgreSQL = "postgres"
|
||||
)
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||
var UsingMySQL = false
|
||||
var UsingClickHouse = false
|
||||
|
||||
|
||||
@@ -92,12 +92,12 @@ func RedisDel(key string) error {
|
||||
return RDB.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisHDelObj(key string) error {
|
||||
func RedisDelKey(key string) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
|
||||
SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
|
||||
}
|
||||
ctx := context.Background()
|
||||
return RDB.HDel(ctx, key).Err()
|
||||
return RDB.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||
@@ -141,7 +141,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||
|
||||
txn := RDB.TxPipeline()
|
||||
txn.HSet(ctx, key, data)
|
||||
txn.Expire(ctx, key, expiration)
|
||||
|
||||
// 只有在 expiration 大于 0 时才设置过期时间
|
||||
if expiration > 0 {
|
||||
txn.Expire(ctx, key, expiration)
|
||||
}
|
||||
|
||||
_, err := txn.Exec(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
||||
}
|
||||
|
||||
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
|
||||
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
|
||||
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||
}
|
||||
durationStr := string(bytes.TrimSpace(output))
|
||||
if durationStr == "N/A" {
|
||||
// Create a temporary output file name
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||
}
|
||||
tmpName := tmpFp.Name()
|
||||
// Close immediately so ffmpeg can open the file on Windows.
|
||||
_ = tmpFp.Close()
|
||||
defer os.Remove(tmpName)
|
||||
|
||||
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
|
||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||
if err := ffmpegCmd.Run(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||
}
|
||||
|
||||
// Recalculate the duration of the new file
|
||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||
}
|
||||
durationStr = string(bytes.TrimSpace(output))
|
||||
}
|
||||
return strconv.ParseFloat(durationStr, 64)
|
||||
}
|
||||
|
||||
// BuildURL concatenates base and endpoint, returns the complete url string
|
||||
func BuildURL(base string, endpoint string) string {
|
||||
u, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
end := endpoint
|
||||
if end == "" {
|
||||
end = "/"
|
||||
}
|
||||
ref, err := url.Parse(end)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
return u.ResolveReference(ref).String()
|
||||
}
|
||||
|
||||
@@ -2,12 +2,10 @@ package constant
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
var (
|
||||
TokenCacheSeconds = common.SyncFrequency
|
||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||
)
|
||||
// 使用函数来避免初始化顺序带来的赋值问题
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return common.SyncFrequency
|
||||
}
|
||||
|
||||
// Cache keys
|
||||
const (
|
||||
|
||||
@@ -5,6 +5,7 @@ type TaskPlatform string
|
||||
const (
|
||||
TaskPlatformSuno TaskPlatform = "suno"
|
||||
TaskPlatformMidjourney = "mj"
|
||||
TaskPlatformKling TaskPlatform = "kling"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,6 +7,7 @@ var (
|
||||
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
||||
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
||||
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -4,11 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/shopspring/decimal"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -304,6 +306,40 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.moonshot.cn/v1/users/me/balance"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
type MoonshotBalanceData struct {
|
||||
AvailableBalance float64 `json:"available_balance"`
|
||||
VoucherBalance float64 `json:"voucher_balance"`
|
||||
CashBalance float64 `json:"cash_balance"`
|
||||
}
|
||||
|
||||
type MoonshotBalanceResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data MoonshotBalanceData `json:"data"`
|
||||
Scode string `json:"scode"`
|
||||
Status bool `json:"status"`
|
||||
}
|
||||
|
||||
response := MoonshotBalanceResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !response.Status || response.Code != 0 {
|
||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||
}
|
||||
availableBalanceCny := response.Data.AvailableBalance
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||
channel.UpdateBalance(availableBalanceUsd)
|
||||
return availableBalanceUsd, nil
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
@@ -332,6 +368,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
return updateChannelDeepSeekBalance(channel)
|
||||
case common.ChannelTypeOpenRouter:
|
||||
return updateChannelOpenRouterBalance(channel)
|
||||
case common.ChannelTypeMoonshot:
|
||||
return updateChannelMoonshotBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
}
|
||||
|
||||
@@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if channel.Type == common.ChannelTypeSunoAPI {
|
||||
return errors.New("suno channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeKling {
|
||||
return errors.New("kling channel test is not supported"), nil
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -90,7 +93,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info)
|
||||
err = helper.ModelMappedHelper(c, info, nil)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
@@ -165,8 +168,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
@@ -200,10 +203,10 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
} else {
|
||||
testRequest.MaxTokens = 10
|
||||
}
|
||||
content, _ := json.Marshal("hi")
|
||||
|
||||
testMessage := dto.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
Content: "hi",
|
||||
}
|
||||
testRequest.Model = model
|
||||
testRequest.Messages = append(testRequest.Messages, testMessage)
|
||||
@@ -271,6 +274,13 @@ func testAllChannels(notify bool) error {
|
||||
disableThreshold = 10000000 // a impossible value
|
||||
}
|
||||
gopool.Go(func() {
|
||||
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
|
||||
defer func() {
|
||||
testAllChannelsLock.Lock()
|
||||
testAllChannelsRunning = false
|
||||
testAllChannelsLock.Unlock()
|
||||
}()
|
||||
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
@@ -305,9 +315,7 @@ func testAllChannels(notify bool) error {
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
testAllChannelsLock.Lock()
|
||||
testAllChannelsRunning = false
|
||||
testAllChannelsLock.Unlock()
|
||||
|
||||
if notify {
|
||||
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||||
}
|
||||
|
||||
@@ -43,22 +43,31 @@ type OpenAIModelsResponse struct {
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
if pageSize < 1 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
// type filter
|
||||
typeStr := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeStr != "" {
|
||||
if t, err := strconv.Atoi(typeStr); err == nil {
|
||||
typeFilter = t
|
||||
}
|
||||
}
|
||||
|
||||
var total int64
|
||||
|
||||
if enableTagMode {
|
||||
tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
|
||||
// tag 分页:先分页 tag,再取各 tag 下 channels
|
||||
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
for _, tag := range tags {
|
||||
@@ -69,21 +78,39 @@ func GetAllChannels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
|
||||
// 计算 tag 总数用于分页
|
||||
total, _ = model.CountAllTags()
|
||||
} else if typeFilter >= 0 {
|
||||
channels, err := model.GetChannelsByType((p-1)*pageSize, pageSize, idSort, typeFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
channelData = channels
|
||||
total, _ = model.CountChannelsByType(typeFilter)
|
||||
} else {
|
||||
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
channelData = channels
|
||||
total, _ = model.CountAllChannels()
|
||||
}
|
||||
|
||||
// calculate type counts
|
||||
typeCounts, _ := model.CountChannelsGroupByType()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelData,
|
||||
"data": gin.H{
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -119,8 +146,11 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
if channel.Type == common.ChannelTypeGemini {
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeGemini:
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
case common.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
@@ -207,10 +237,20 @@ func SearchChannels(c *gin.Context) {
|
||||
}
|
||||
channelData = channels
|
||||
}
|
||||
|
||||
// calculate type counts for search results
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, channel := range channelData {
|
||||
typeCounts[int64(channel.Type)]++
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelData,
|
||||
"data": gin.H{
|
||||
"items": channelData,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -620,3 +660,44 @@ func BatchSetChannelTag(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetTagModels(c *gin.Context) {
|
||||
tag := c.Query("tag")
|
||||
if tag == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "tag不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var longestModels string
|
||||
maxLength := 0
|
||||
|
||||
// Find the longest models string among all channels with the given tag
|
||||
for _, channel := range channels {
|
||||
if channel.Models != "" {
|
||||
currentModels := strings.Split(channel.Models, ",")
|
||||
if len(currentModels) > maxLength {
|
||||
maxLength = len(currentModels)
|
||||
longestModels = channel.Models
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": longestModels,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
103
controller/console_migrate.go
Normal file
103
controller/console_migrate.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// 用于迁移检测的旧键,该文件下个版本会删除
|
||||
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
||||
func MigrateConsoleSetting(c *gin.Context) {
|
||||
// 读取全部 option
|
||||
opts, err := model.AllOption()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// 建立 map
|
||||
valMap := map[string]string{}
|
||||
for _, o := range opts {
|
||||
valMap[o.Key] = o.Value
|
||||
}
|
||||
|
||||
// 处理 APIInfo
|
||||
if v := valMap["ApiInfo"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
if len(arr) > 50 {
|
||||
arr = arr[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(arr)
|
||||
model.UpdateOption("console_setting.api_info", string(bytes))
|
||||
}
|
||||
model.UpdateOption("ApiInfo", "")
|
||||
}
|
||||
// Announcements 直接搬
|
||||
if v := valMap["Announcements"]; v != "" {
|
||||
model.UpdateOption("console_setting.announcements", v)
|
||||
model.UpdateOption("Announcements", "")
|
||||
}
|
||||
// FAQ 转换
|
||||
if v := valMap["FAQ"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
out := []map[string]interface{}{}
|
||||
for _, item := range arr {
|
||||
q, _ := item["question"].(string)
|
||||
if q == "" {
|
||||
q, _ = item["title"].(string)
|
||||
}
|
||||
a, _ := item["answer"].(string)
|
||||
if a == "" {
|
||||
a, _ = item["content"].(string)
|
||||
}
|
||||
if q != "" && a != "" {
|
||||
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
||||
}
|
||||
}
|
||||
if len(out) > 50 {
|
||||
out = out[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(out)
|
||||
model.UpdateOption("console_setting.faq", string(bytes))
|
||||
}
|
||||
model.UpdateOption("FAQ", "")
|
||||
}
|
||||
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
||||
url := valMap["UptimeKumaUrl"]
|
||||
slug := valMap["UptimeKumaSlug"]
|
||||
if url != "" && slug != "" {
|
||||
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
||||
groups := []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"categoryName": "old",
|
||||
"url": url,
|
||||
"slug": slug,
|
||||
"description": "",
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(groups)
|
||||
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
||||
}
|
||||
// 清空旧键内容
|
||||
if url != "" {
|
||||
model.UpdateOption("UptimeKumaUrl", "")
|
||||
}
|
||||
if slug != "" {
|
||||
model.UpdateOption("UptimeKumaSlug", "")
|
||||
}
|
||||
|
||||
// 删除旧键记录
|
||||
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
||||
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
||||
|
||||
// 重新加载 OptionMap
|
||||
model.InitOptionMap()
|
||||
common.SysLog("console setting migrated")
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||
}
|
||||
@@ -1,15 +1,17 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
for groupName := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.GetUserGroup(userId, false)
|
||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
||||
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if setting.GroupInUserUsableGroups("auto") {
|
||||
usableGroups["auto"] = map[string]interface{}{
|
||||
"ratio": "自动",
|
||||
"desc": setting.GetUsableGroupDescription("auto"),
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -215,8 +214,12 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
|
||||
// 解析其他查询参数
|
||||
@@ -227,31 +230,38 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
||||
total := model.CountAllTasks(queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
log.Printf("userId = %d \n", userId)
|
||||
|
||||
queryParams := model.TaskQueryParams{
|
||||
MjID: c.Query("mj_id"),
|
||||
@@ -259,19 +269,23 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
||||
total := model.CountAllUserTask(userId, queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
@@ -24,57 +26,85 @@ func TestStatus(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// 获取HTTP统计信息
|
||||
httpStats := middleware.GetStats()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Server is running",
|
||||
"success": true,
|
||||
"message": "Server is running",
|
||||
"http_stats": httpStats,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
"setup": constant.Setup,
|
||||
}
|
||||
|
||||
// 根据启用状态注入可选内容
|
||||
if cs.ApiInfoEnabled {
|
||||
data["api_info"] = console_setting.GetApiInfo()
|
||||
}
|
||||
if cs.AnnouncementsEnabled {
|
||||
data["announcements"] = console_setting.GetAnnouncements()
|
||||
}
|
||||
if cs.FAQEnabled {
|
||||
data["faq"] = console_setting.GetFAQ()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
"setup": constant.Setup,
|
||||
},
|
||||
"data": data,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -15,6 +15,9 @@ import (
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -134,6 +137,9 @@ func init() {
|
||||
adaptor.Init(meta)
|
||||
channelId2Models[i] = adaptor.GetModelList()
|
||||
}
|
||||
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||||
return m.Id
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
@@ -179,7 +185,19 @@ func ListModels(c *gin.Context) {
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
models := model.GetGroupModels(group)
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
groupModels := model.GetGroupModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupModels(group)
|
||||
}
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
@@ -102,7 +104,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = setting.CheckGroupRatio(option.Value)
|
||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -119,7 +121,42 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
case "console_setting.api_info":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.announcements":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.faq":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.uptime_kuma_groups":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
err = model.UpdateOption(option.Key, option.Value)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,7 +3,6 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -13,6 +12,8 @@ import (
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
@@ -57,13 +58,22 @@ func Playground(c *gin.Context) {
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
||||
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userId := c.GetInt("id")
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
Relay(c)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetPricing(c *gin.Context) {
|
||||
@@ -12,7 +13,7 @@ func GetPricing(c *gin.Context) {
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := map[string]float64{}
|
||||
for s, f := range setting.GetGroupRatioCopy() {
|
||||
for s, f := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupRatio[s] = f
|
||||
}
|
||||
var group string
|
||||
@@ -20,12 +21,18 @@ func GetPricing(c *gin.Context) {
|
||||
user, err := model.GetUserCache(userId.(int))
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
for g := range groupRatio {
|
||||
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
|
||||
if ok {
|
||||
groupRatio[g] = ratio
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
usableGroup = setting.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range setting.GetGroupRatioCopy() {
|
||||
for group := range ratio_setting.GetGroupRatioCopy() {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
@@ -40,7 +47,7 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
func ResetModelRatio(c *gin.Context) {
|
||||
defaultStr := operation_setting.DefaultModelRatio2JSONString()
|
||||
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
|
||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
@@ -49,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
|
||||
24
controller/ratio_config.go
Normal file
24
controller/ratio_config.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRatioConfig(c *gin.Context) {
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
322
controller/ratio_sync.go
Normal file
322
controller/ratio_sync.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
maxConcurrentFetches = 8
|
||||
)
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
|
||||
type upstreamResult struct {
|
||||
Name string `json:"name"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Timeout <= 0 {
|
||||
req.Timeout = defaultTimeoutSeconds
|
||||
}
|
||||
|
||||
var upstreams []dto.UpstreamDTO
|
||||
|
||||
if len(req.ChannelIDs) > 0 {
|
||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||
for _, id64 := range req.ChannelIDs {
|
||||
intIds = append(intIds, int(id64))
|
||||
}
|
||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||
if err != nil {
|
||||
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||
return
|
||||
}
|
||||
for _, ch := range dbChannels {
|
||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||
Name: ch.Name,
|
||||
BaseURL: strings.TrimRight(base, "/"),
|
||||
Endpoint: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(upstreams) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan upstreamResult, len(upstreams))
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
|
||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(chItem dto.UpstreamDTO) {
|
||||
defer wg.Done()
|
||||
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
endpoint := chItem.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL := chItem.BaseURL + endpoint
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||
ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Success bool `json:"success"`
|
||||
Data map[string]any `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
if !body.Success {
|
||||
ch <- upstreamResult{Name: chItem.Name, Err: body.Message}
|
||||
return
|
||||
}
|
||||
ch <- upstreamResult{Name: chItem.Name, Data: body.Data}
|
||||
}(chn)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
for r := range ch {
|
||||
if r.Err != "" {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "error",
|
||||
Error: r.Err,
|
||||
})
|
||||
} else {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "success",
|
||||
})
|
||||
successfulChannels = append(successfulChannels, struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}{name: r.Name, data: r.Data})
|
||||
}
|
||||
}
|
||||
|
||||
differences := buildDifferences(localData, successfulChannels)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"differences": differences,
|
||||
"test_results": testResults,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}) map[string]map[string]dto.DifferenceItem {
|
||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
hasUpstreamValue := false
|
||||
hasDifference := false
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && localValue != val {
|
||||
hasDifference = true
|
||||
} else if localValue == val {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
|
||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||
hasDifference = true
|
||||
}
|
||||
|
||||
upstreamValues[channel.name] = upstreamValue
|
||||
}
|
||||
|
||||
shouldInclude := false
|
||||
|
||||
if localValue != nil {
|
||||
if hasDifference {
|
||||
shouldInclude = true
|
||||
}
|
||||
} else {
|
||||
if hasUpstreamValue {
|
||||
shouldInclude = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldInclude {
|
||||
if differences[modelName] == nil {
|
||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||
}
|
||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||
Current: localValue,
|
||||
Upstreams: upstreamValues,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
channelHasDiff := make(map[string]bool)
|
||||
for _, ratioMap := range differences {
|
||||
for _, item := range ratioMap {
|
||||
for chName, val := range item.Upstreams {
|
||||
if val != nil && val != "same" {
|
||||
channelHasDiff[chName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName, ratioMap := range differences {
|
||||
for ratioType, item := range ratioMap {
|
||||
for chName := range item.Upstreams {
|
||||
if !channelHasDiff[chName] {
|
||||
delete(item.Upstreams, chName)
|
||||
}
|
||||
}
|
||||
differences[modelName][ratioType] = item
|
||||
}
|
||||
}
|
||||
|
||||
return differences
|
||||
}
|
||||
|
||||
func GetSyncableChannels(c *gin.Context) {
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var syncableChannels []dto.SyncableChannel
|
||||
for _, channel := range channels {
|
||||
if channel.GetBaseURL() != "" {
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: channel.Id,
|
||||
Name: channel.Name,
|
||||
BaseURL: channel.GetBaseURL(),
|
||||
Status: channel.Status,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": syncableChannels,
|
||||
})
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -126,6 +127,10 @@ func AddRedemption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
var keys []string
|
||||
for i := 0; i < redemption.Count; i++ {
|
||||
key := common.GetUUID()
|
||||
@@ -135,6 +140,7 @@ func AddRedemption(c *gin.Context) {
|
||||
Key: key,
|
||||
CreatedTime: common.GetTimestamp(),
|
||||
Quota: redemption.Quota,
|
||||
ExpiredTime: redemption.ExpiredTime,
|
||||
}
|
||||
err = cleanRedemption.Insert()
|
||||
if err != nil {
|
||||
@@ -191,12 +197,18 @@ func UpdateRedemption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanRedemption.Status = redemption.Status
|
||||
} else {
|
||||
if statusOnly == "" {
|
||||
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// If you add more fields, please also update redemption.Update()
|
||||
cleanRedemption.Name = redemption.Name
|
||||
cleanRedemption.Quota = redemption.Quota
|
||||
cleanRedemption.ExpiredTime = redemption.ExpiredTime
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanRedemption.Status = redemption.Status
|
||||
}
|
||||
err = cleanRedemption.Update()
|
||||
if err != nil {
|
||||
@@ -213,3 +225,27 @@ func UpdateRedemption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteInvalidRedemption(c *gin.Context) {
|
||||
rows, err := model.DeleteInvalidRedemptions()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": rows,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func validateExpiredTime(expired int64) error {
|
||||
if expired != 0 && expired < common.GetTimestamp() {
|
||||
return errors.New("过期时间不能早于当前时间")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
err = relay.EmbeddingHelper(c)
|
||||
case relayconstant.RelayModeResponses:
|
||||
err = relay.ResponsesHelper(c)
|
||||
case relayconstant.RelayModeGemini:
|
||||
err = relay.GeminiHelper(c)
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
@@ -257,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||
}
|
||||
@@ -386,7 +388,7 @@ func RelayTask(c *gin.Context) {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
break
|
||||
@@ -418,7 +420,7 @@ func RelayTask(c *gin.Context) {
|
||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayMode)
|
||||
|
||||
@@ -75,6 +75,14 @@ func PostSetup(c *gin.Context) {
|
||||
|
||||
// If root doesn't exist, validate and create admin account
|
||||
if !rootExists {
|
||||
// Validate username length: max 12 characters to align with model.User validation
|
||||
if len(req.Username) > 12 {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "用户名长度不能超过12个字符",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Validate password
|
||||
if req.Password != req.ConfirmPassword {
|
||||
c.JSON(400, gin.H{
|
||||
|
||||
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
case constant.TaskPlatformKling:
|
||||
_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
default:
|
||||
common.SysLog("未知平台")
|
||||
}
|
||||
@@ -224,9 +226,14 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
|
||||
|
||||
func GetAllTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
// 解析其他查询参数
|
||||
@@ -237,24 +244,32 @@ func GetAllTask(c *gin.Context) {
|
||||
Action: c.Query("action"),
|
||||
StartTimestamp: startTimestamp,
|
||||
EndTimestamp: endTimestamp,
|
||||
ChannelID: c.Query("channel_id"),
|
||||
}
|
||||
|
||||
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Task, 0)
|
||||
}
|
||||
items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
||||
total := model.TaskCountAllTasks(queryParams)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
@@ -271,14 +286,17 @@ func GetUserTask(c *gin.Context) {
|
||||
EndTimestamp: endTimestamp,
|
||||
}
|
||||
|
||||
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Task, 0)
|
||||
}
|
||||
items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
140
controller/task_video.go
Normal file
140
controller/task_video.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
"task_id": taskId,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
var responseItem map[string]interface{}
|
||||
err = json.Unmarshal(responseBody, &responseItem)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
|
||||
return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
code, _ := responseItem["code"].(float64)
|
||||
if code != 0 {
|
||||
return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||
}
|
||||
|
||||
data, ok := responseItem["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
|
||||
return fmt.Errorf("video task data format error for task %s", taskId)
|
||||
}
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
|
||||
if status, ok := data["task_status"].(string); ok {
|
||||
switch status {
|
||||
case "submitted", "queued":
|
||||
task.Status = model.TaskStatusSubmitted
|
||||
case "processing":
|
||||
task.Status = model.TaskStatusInProgress
|
||||
case "succeed":
|
||||
task.Status = model.TaskStatusSuccess
|
||||
task.Progress = "100%"
|
||||
if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
|
||||
task.FailReason = url
|
||||
} else {
|
||||
common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
|
||||
}
|
||||
case "failed":
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if reason, ok := data["fail_reason"].(string); ok {
|
||||
task.FailReason = reason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If task failed, refund quota
|
||||
if task.Status == model.TaskStatusFailure {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
|
||||
task.Data = responseBody
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -12,15 +12,15 @@ func GetAllTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
size, _ := strconv.Atoi(c.Query("size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if size <= 0 {
|
||||
size = common.ItemsPerPage
|
||||
} else if size > 100 {
|
||||
size = 100
|
||||
}
|
||||
tokens, err := model.GetAllUserTokens(userId, p*size, size)
|
||||
tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -28,10 +28,18 @@ func GetAllTokens(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Get total count for pagination
|
||||
total, _ := model.CountUserTokens(userId)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": tokens,
|
||||
"data": gin.H{
|
||||
"items": tokens,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": size,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -97,16 +97,14 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
payType := "wxpay"
|
||||
if req.PaymentMethod == "zfb" {
|
||||
payType = "alipay"
|
||||
}
|
||||
if req.PaymentMethod == "wx" {
|
||||
req.PaymentMethod = "wxpay"
|
||||
payType = "wxpay"
|
||||
|
||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: payType,
|
||||
Type: req.PaymentMethod,
|
||||
ServiceTradeNo: tradeNo,
|
||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
|
||||
154
controller/uptime_kuma.go
Normal file
154
controller/uptime_kuma.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/setting/console_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
requestTimeout = 30 * time.Second
|
||||
httpTimeout = 10 * time.Second
|
||||
uptimeKeySuffix = "_24"
|
||||
apiStatusPath = "/api/status-page/"
|
||||
apiHeartbeatPath = "/api/status-page/heartbeat/"
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
Name string `json:"name"`
|
||||
Uptime float64 `json:"uptime"`
|
||||
Status int `json:"status"`
|
||||
Group string `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type UptimeGroupResult struct {
|
||||
CategoryName string `json:"categoryName"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
}
|
||||
|
||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New("non-200 status")
|
||||
}
|
||||
|
||||
return json.NewDecoder(resp.Body).Decode(dest)
|
||||
}
|
||||
|
||||
func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
|
||||
url, _ := groupConfig["url"].(string)
|
||||
slug, _ := groupConfig["slug"].(string)
|
||||
categoryName, _ := groupConfig["categoryName"].(string)
|
||||
|
||||
result := UptimeGroupResult{
|
||||
CategoryName: categoryName,
|
||||
Monitors: []Monitor{},
|
||||
}
|
||||
|
||||
if url == "" || slug == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(url, "/")
|
||||
|
||||
var statusData struct {
|
||||
PublicGroupList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MonitorList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"monitorList"`
|
||||
} `json:"publicGroupList"`
|
||||
}
|
||||
|
||||
var heartbeatData struct {
|
||||
HeartbeatList map[string][]struct {
|
||||
Status int `json:"status"`
|
||||
} `json:"heartbeatList"`
|
||||
UptimeList map[string]float64 `json:"uptimeList"`
|
||||
}
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
})
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
})
|
||||
|
||||
if g.Wait() != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, pg := range statusData.PublicGroupList {
|
||||
if len(pg.MonitorList) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, m := range pg.MonitorList {
|
||||
monitor := Monitor{
|
||||
Name: m.Name,
|
||||
Group: pg.Name,
|
||||
}
|
||||
|
||||
monitorID := strconv.Itoa(m.ID)
|
||||
|
||||
if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
|
||||
monitor.Uptime = uptime
|
||||
}
|
||||
|
||||
if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
|
||||
monitor.Status = heartbeats[0].Status
|
||||
}
|
||||
|
||||
result.Monitors = append(result.Monitors, monitor)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func GetUptimeKumaStatus(c *gin.Context) {
|
||||
groups := console_setting.GetUptimeKumaGroups()
|
||||
if len(groups) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
|
||||
defer cancel()
|
||||
|
||||
client := &http.Client{Timeout: httpTimeout}
|
||||
results := make([]UptimeGroupResult, len(groups))
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
for i, group := range groups {
|
||||
i, group := i, group
|
||||
g.Go(func() error {
|
||||
results[i] = fetchGroupData(gCtx, client, group)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
g.Wait()
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||
}
|
||||
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
|
||||
UnlimitedQuota: true,
|
||||
ModelLimitsEnabled: false,
|
||||
}
|
||||
if setting.DefaultUseAutoGroup {
|
||||
token.Group = "auto"
|
||||
}
|
||||
if err := token.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -459,6 +462,9 @@ func GetSelf(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||
user.Remark = ""
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -943,6 +949,7 @@ type UpdateUserSettingRequest struct {
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||
NotificationEmail string `json:"notification_email,omitempty"`
|
||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||
RecordIpLog bool `json:"record_ip_log"`
|
||||
}
|
||||
|
||||
func UpdateUserSetting(c *gin.Context) {
|
||||
@@ -1019,6 +1026,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
constant.UserSettingNotifyType: req.QuotaWarningType,
|
||||
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
|
||||
constant.UserSettingRecordIpLog: req.RecordIpLog,
|
||||
}
|
||||
|
||||
// 如果是webhook类型,添加webhook相关设置
|
||||
|
||||
137
dto/claude.go
137
dto/claude.go
@@ -1,29 +1,33 @@
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeMediaMessage struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
PartialJson *string `json:"partial_json,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
PartialJson *string `json:"partial_json,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
CacheControl json.RawMessage `json:"cache_control,omitempty"`
|
||||
// tool_calls
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ToolUseId string `json:"tool_use_id,omitempty"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Content any `json:"content,omitempty"`
|
||||
ToolUseId string `json:"tool_use_id,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) SetText(s string) {
|
||||
@@ -38,15 +42,39 @@ func (c *ClaudeMediaMessage) GetText() string {
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) IsStringContent() bool {
|
||||
var content string
|
||||
return json.Unmarshal(c.Content, &content) == nil
|
||||
if c.Content == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := c.Content.(string)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) GetStringContent() string {
|
||||
var content string
|
||||
if err := json.Unmarshal(c.Content, &content); err == nil {
|
||||
return content
|
||||
if c.Content == nil {
|
||||
return ""
|
||||
}
|
||||
switch c.Content.(type) {
|
||||
case string:
|
||||
return c.Content.(string)
|
||||
case []any:
|
||||
var contentStr string
|
||||
for _, contentItem := range c.Content.([]any) {
|
||||
contentMap, ok := contentItem.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if contentMap["type"] == ContentTypeText {
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentStr += subStr
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentStr
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -56,16 +84,12 @@ func (c *ClaudeMediaMessage) GetJsonRowString() string {
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) SetContent(content any) {
|
||||
jsonContent, _ := json.Marshal(content)
|
||||
c.Content = jsonContent
|
||||
c.Content = content
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
|
||||
var mediaContent []ClaudeMediaMessage
|
||||
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
|
||||
return mediaContent
|
||||
}
|
||||
return make([]ClaudeMediaMessage, 0)
|
||||
mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.Content)
|
||||
return mediaContent
|
||||
}
|
||||
|
||||
type ClaudeMessageSource struct {
|
||||
@@ -81,14 +105,36 @@ type ClaudeMessage struct {
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) IsStringContent() bool {
|
||||
if c.Content == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := c.Content.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) GetStringContent() string {
|
||||
if c.IsStringContent() {
|
||||
return c.Content.(string)
|
||||
if c.Content == nil {
|
||||
return ""
|
||||
}
|
||||
switch c.Content.(type) {
|
||||
case string:
|
||||
return c.Content.(string)
|
||||
case []any:
|
||||
var contentStr string
|
||||
for _, contentItem := range c.Content.([]any) {
|
||||
contentMap, ok := contentItem.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if contentMap["type"] == ContentTypeText {
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentStr += subStr
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentStr
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -97,15 +143,7 @@ func (c *ClaudeMessage) SetStringContent(content string) {
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
|
||||
// map content to []ClaudeMediaMessage
|
||||
// parse to json
|
||||
jsonContent, _ := json.Marshal(c.Content)
|
||||
var contentList []ClaudeMediaMessage
|
||||
err := json.Unmarshal(jsonContent, &contentList)
|
||||
if err != nil {
|
||||
return make([]ClaudeMediaMessage, 0), err
|
||||
}
|
||||
return contentList, nil
|
||||
return common.Any2Type[[]ClaudeMediaMessage](c.Content)
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
@@ -140,7 +178,14 @@ type ClaudeRequest struct {
|
||||
|
||||
type Thinking struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens int `json:"budget_tokens"`
|
||||
BudgetTokens *int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Thinking) GetBudgetTokens() int {
|
||||
if c.BudgetTokens == nil {
|
||||
return 0
|
||||
}
|
||||
return *c.BudgetTokens
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) IsStringSystem() bool {
|
||||
@@ -160,14 +205,8 @@ func (c *ClaudeRequest) SetStringSystem(system string) {
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
||||
// map content to []ClaudeMediaMessage
|
||||
// parse to json
|
||||
jsonContent, _ := json.Marshal(c.System)
|
||||
var contentList []ClaudeMediaMessage
|
||||
if err := json.Unmarshal(jsonContent, &contentList); err == nil {
|
||||
return contentList
|
||||
}
|
||||
return make([]ClaudeMediaMessage, 0)
|
||||
mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.System)
|
||||
return mediaContent
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
|
||||
@@ -14,6 +14,8 @@ type ImageRequest struct {
|
||||
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
|
||||
@@ -2,6 +2,7 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -18,41 +19,54 @@ type FormatJsonSchema struct {
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Prefix any `json:"prefix,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
//Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCallRequest `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
ExtraBody any `json:"extra_body,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Prefix any `json:"prefix,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions json.RawMessage `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
EncodingFormat json.RawMessage `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
Tools []ToolCallRequest `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// OpenRouter Params
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
result := make(map[string]any)
|
||||
data, _ := common.EncodeJson(r)
|
||||
_ = common.DecodeJson(data, &result)
|
||||
return result
|
||||
}
|
||||
|
||||
type ToolCallRequest struct {
|
||||
@@ -72,11 +86,11 @@ type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) GetMaxTokens() int {
|
||||
func (r *GeneralOpenAIRequest) GetMaxTokens() int {
|
||||
return int(r.MaxTokens)
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
func (r *GeneralOpenAIRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -96,16 +110,16 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Prefix *bool `json:"prefix,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Reasoning string `json:"reasoning,omitempty"`
|
||||
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
parsedContent []MediaContent
|
||||
parsedStringContent *string
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Prefix *bool `json:"prefix,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Reasoning string `json:"reasoning,omitempty"`
|
||||
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
parsedContent []MediaContent
|
||||
//parsedStringContent *string
|
||||
}
|
||||
|
||||
type MediaContent struct {
|
||||
@@ -115,25 +129,56 @@ type MediaContent struct {
|
||||
InputAudio any `json:"input_audio,omitempty"`
|
||||
File any `json:"file,omitempty"`
|
||||
VideoUrl any `json:"video_url,omitempty"`
|
||||
// OpenRouter Params
|
||||
CacheControl json.RawMessage `json:"cache_control,omitempty"`
|
||||
}
|
||||
|
||||
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
|
||||
if m.ImageUrl != nil {
|
||||
return m.ImageUrl.(*MessageImageUrl)
|
||||
if _, ok := m.ImageUrl.(*MessageImageUrl); ok {
|
||||
return m.ImageUrl.(*MessageImageUrl)
|
||||
}
|
||||
if itemMap, ok := m.ImageUrl.(map[string]any); ok {
|
||||
out := &MessageImageUrl{
|
||||
Url: common.Interface2String(itemMap["url"]),
|
||||
Detail: common.Interface2String(itemMap["detail"]),
|
||||
MimeType: common.Interface2String(itemMap["mime_type"]),
|
||||
}
|
||||
return out
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
|
||||
if m.InputAudio != nil {
|
||||
return m.InputAudio.(*MessageInputAudio)
|
||||
if _, ok := m.InputAudio.(*MessageInputAudio); ok {
|
||||
return m.InputAudio.(*MessageInputAudio)
|
||||
}
|
||||
if itemMap, ok := m.InputAudio.(map[string]any); ok {
|
||||
out := &MessageInputAudio{
|
||||
Data: common.Interface2String(itemMap["data"]),
|
||||
Format: common.Interface2String(itemMap["format"]),
|
||||
}
|
||||
return out
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MediaContent) GetFile() *MessageFile {
|
||||
if m.File != nil {
|
||||
return m.File.(*MessageFile)
|
||||
if _, ok := m.File.(*MessageFile); ok {
|
||||
return m.File.(*MessageFile)
|
||||
}
|
||||
if itemMap, ok := m.File.(map[string]any); ok {
|
||||
out := &MessageFile{
|
||||
FileName: common.Interface2String(itemMap["file_name"]),
|
||||
FileData: common.Interface2String(itemMap["file_data"]),
|
||||
FileId: common.Interface2String(itemMap["file_id"]),
|
||||
}
|
||||
return out
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -199,6 +244,186 @@ func (m *Message) SetToolCalls(toolCalls any) {
|
||||
}
|
||||
|
||||
func (m *Message) StringContent() string {
|
||||
switch m.Content.(type) {
|
||||
case string:
|
||||
return m.Content.(string)
|
||||
case []any:
|
||||
var contentStr string
|
||||
for _, contentItem := range m.Content.([]any) {
|
||||
contentMap, ok := contentItem.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if contentMap["type"] == ContentTypeText {
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentStr += subStr
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentStr
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Message) SetNullContent() {
|
||||
m.Content = nil
|
||||
m.parsedContent = nil
|
||||
}
|
||||
|
||||
func (m *Message) SetStringContent(content string) {
|
||||
m.Content = content
|
||||
m.parsedContent = nil
|
||||
}
|
||||
|
||||
func (m *Message) SetMediaContent(content []MediaContent) {
|
||||
m.Content = content
|
||||
m.parsedContent = content
|
||||
}
|
||||
|
||||
func (m *Message) IsStringContent() bool {
|
||||
_, ok := m.Content.(string)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Message) ParseContent() []MediaContent {
|
||||
if m.Content == nil {
|
||||
return nil
|
||||
}
|
||||
if len(m.parsedContent) > 0 {
|
||||
return m.parsedContent
|
||||
}
|
||||
|
||||
var contentList []MediaContent
|
||||
// 先尝试解析为字符串
|
||||
content, ok := m.Content.(string)
|
||||
if ok {
|
||||
contentList = []MediaContent{{
|
||||
Type: ContentTypeText,
|
||||
Text: content,
|
||||
}}
|
||||
m.parsedContent = contentList
|
||||
return contentList
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
//var arrayContent []map[string]interface{}
|
||||
|
||||
arrayContent, ok := m.Content.([]any)
|
||||
if !ok {
|
||||
return contentList
|
||||
}
|
||||
|
||||
for _, contentItemAny := range arrayContent {
|
||||
mediaItem, ok := contentItemAny.(MediaContent)
|
||||
if ok {
|
||||
contentList = append(contentList, mediaItem)
|
||||
continue
|
||||
}
|
||||
|
||||
contentItem, ok := contentItemAny.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
contentType, ok := contentItem["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch contentType {
|
||||
case ContentTypeText:
|
||||
if text, ok := contentItem["text"].(string); ok {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeText,
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
|
||||
case ContentTypeImageURL:
|
||||
imageUrl := contentItem["image_url"]
|
||||
temp := &MessageImageUrl{
|
||||
Detail: "high",
|
||||
}
|
||||
switch v := imageUrl.(type) {
|
||||
case string:
|
||||
temp.Url = v
|
||||
case map[string]interface{}:
|
||||
url, ok1 := v["url"].(string)
|
||||
detail, ok2 := v["detail"].(string)
|
||||
if ok2 {
|
||||
temp.Detail = detail
|
||||
}
|
||||
if ok1 {
|
||||
temp.Url = url
|
||||
}
|
||||
}
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: temp,
|
||||
})
|
||||
|
||||
case ContentTypeInputAudio:
|
||||
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
|
||||
data, ok1 := audioData["data"].(string)
|
||||
format, ok2 := audioData["format"].(string)
|
||||
if ok1 && ok2 {
|
||||
temp := &MessageInputAudio{
|
||||
Data: data,
|
||||
Format: format,
|
||||
}
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: temp,
|
||||
})
|
||||
}
|
||||
}
|
||||
case ContentTypeFile:
|
||||
if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
|
||||
fileId, ok3 := fileData["file_id"].(string)
|
||||
if ok3 {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeFile,
|
||||
File: &MessageFile{
|
||||
FileId: fileId,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
fileName, ok1 := fileData["filename"].(string)
|
||||
fileDataStr, ok2 := fileData["file_data"].(string)
|
||||
if ok1 && ok2 {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeFile,
|
||||
File: &MessageFile{
|
||||
FileName: fileName,
|
||||
FileData: fileDataStr,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
case ContentTypeVideoUrl:
|
||||
if videoUrl, ok := contentItem["video_url"].(string); ok {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeVideoUrl,
|
||||
VideoUrl: &MessageVideoUrl{
|
||||
Url: videoUrl,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(contentList) > 0 {
|
||||
m.parsedContent = contentList
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
|
||||
// old code
|
||||
/*func (m *Message) StringContent() string {
|
||||
if m.parsedStringContent != nil {
|
||||
return *m.parsedStringContent
|
||||
}
|
||||
@@ -369,6 +594,11 @@ func (m *Message) ParseContent() []MediaContent {
|
||||
m.parsedContent = contentList
|
||||
}
|
||||
return contentList
|
||||
}*/
|
||||
|
||||
type WebSearchOptions struct {
|
||||
SearchContextSize string `json:"search_context_size,omitempty"`
|
||||
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIResponsesRequest struct {
|
||||
|
||||
49
dto/ratio_sync.go
Normal file
49
dto/ratio_sync.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package dto
|
||||
|
||||
// UpstreamDTO 提交到后端同步倍率的上游渠道信息
|
||||
// Endpoint 可以为空,后端会默认使用 /api/ratio_config
|
||||
// BaseURL 必须以 http/https 开头,不要以 / 结尾
|
||||
// 例如: https://api.example.com
|
||||
// Endpoint: /api/ratio_config
|
||||
// 提交示例:
|
||||
// {
|
||||
// "name": "openai",
|
||||
// "base_url": "https://api.openai.com",
|
||||
// "endpoint": "/ratio_config"
|
||||
// }
|
||||
|
||||
type UpstreamDTO struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type UpstreamRequest struct {
|
||||
ChannelIDs []int64 `json:"channel_ids"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
// TestResult 上游测试连通性结果
|
||||
type TestResult struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// DifferenceItem 差异项
|
||||
// Current 为本地值,可能为 nil
|
||||
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||
|
||||
type DifferenceItem struct {
|
||||
Current interface{} `json:"current"`
|
||||
Upstreams map[string]interface{} `json:"upstreams"`
|
||||
}
|
||||
|
||||
// SyncableChannel 可同步的渠道信息(base_url 不为空)
|
||||
|
||||
type SyncableChannel struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
47
dto/video.go
Normal file
47
dto/video.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package dto
|
||||
|
||||
type VideoRequest struct {
|
||||
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
|
||||
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
|
||||
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
|
||||
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
|
||||
Width int `json:"width" example:"512"` // Video width
|
||||
Height int `json:"height" example:"512"` // Video height
|
||||
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
|
||||
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
|
||||
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
|
||||
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
|
||||
User string `json:"user,omitempty" example:"user-1234"` // User identifier
|
||||
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
|
||||
}
|
||||
|
||||
// VideoResponse 视频生成提交任务后的响应
|
||||
type VideoResponse struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// VideoTaskResponse 查询视频生成任务状态的响应
|
||||
type VideoTaskResponse struct {
|
||||
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
|
||||
Status string `json:"status" example:"succeeded"` // 任务状态
|
||||
Url string `json:"url,omitempty"` // 视频资源URL(成功时)
|
||||
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
|
||||
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
|
||||
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
|
||||
}
|
||||
|
||||
// VideoTaskMetadata 视频任务元数据
|
||||
type VideoTaskMetadata struct {
|
||||
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
|
||||
Fps int `json:"fps" example:"30"` // 实际帧率
|
||||
Width int `json:"width" example:"512"` // 实际宽度
|
||||
Height int `json:"height" example:"512"` // 实际高度
|
||||
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
|
||||
}
|
||||
|
||||
// VideoTaskError 视频任务错误信息
|
||||
type VideoTaskError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
6
go.mod
6
go.mod
@@ -11,7 +11,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/bytedance/sonic v1.11.6
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@@ -25,10 +24,10 @@ require (
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pkoukk/tiktoken-go v0.1.7
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/tiktoken-go/tokenizer v0.6.2
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.35.0
|
||||
@@ -43,12 +42,13 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||
github.com/aws/smithy-go v1.20.2 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.11.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -38,8 +38,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
|
||||
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
@@ -167,8 +167,6 @@ github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
@@ -197,6 +195,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
|
||||
27
main.go
27
main.go
@@ -12,7 +12,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
@@ -74,7 +74,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Initialize model settings
|
||||
operation_setting.InitRatioSettings()
|
||||
ratio_setting.InitRatioSettings()
|
||||
// Initialize constants
|
||||
constant.InitEnv()
|
||||
// Initialize options
|
||||
@@ -89,13 +89,28 @@ func main() {
|
||||
if common.MemoryCacheEnabled {
|
||||
common.SysLog("memory cache enabled")
|
||||
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
||||
model.InitChannelCache()
|
||||
}
|
||||
if common.MemoryCacheEnabled {
|
||||
go model.SyncOptions(common.SyncFrequency)
|
||||
|
||||
// Add panic recovery and retry for InitChannelCache
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
||||
// Retry once
|
||||
_, fixErr := model.FixAbility()
|
||||
if fixErr != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||
}
|
||||
}
|
||||
}()
|
||||
model.InitChannelCache()
|
||||
}()
|
||||
|
||||
go model.SyncChannelCache(common.SyncFrequency)
|
||||
}
|
||||
|
||||
// 热更新配置
|
||||
go model.SyncOptions(common.SyncFrequency)
|
||||
|
||||
// 数据看板
|
||||
go model.UpdateQuotaData()
|
||||
|
||||
|
||||
2
makefile
2
makefile
@@ -7,7 +7,7 @@ all: build-frontend start-backend
|
||||
|
||||
build-frontend:
|
||||
@echo "Building frontend..."
|
||||
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||
@cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
|
||||
start-backend:
|
||||
@echo "Starting backend dev server..."
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func validUserInfo(username string, role int) bool {
|
||||
@@ -182,6 +183,18 @@ func TokenAuth() func(c *gin.Context) {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
}
|
||||
// gemini api 从query中获取key
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||
skKey := c.Query("key")
|
||||
if skKey != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
||||
}
|
||||
// 从x-goog-api-key header中获取key
|
||||
xGoogKey := c.Request.Header.Get("x-goog-api-key")
|
||||
if xGoogKey != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+xGoogKey)
|
||||
}
|
||||
}
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
parts := make([]string, 0)
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -48,9 +49,11 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if !setting.ContainsGroupRatio(tokenGroup) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||
if tokenGroup != "auto" {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
@@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||
var selectGroup string
|
||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
showGroup := userGroup
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
@@ -162,6 +170,23 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||
shouldSelectChannel = false
|
||||
} else {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformKling))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||
relayMode := relayconstant.RelayModeGemini
|
||||
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
|
||||
if modelName != "" {
|
||||
modelRequest.Model = modelName
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
@@ -244,3 +269,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set("bot_id", channel.Other)
|
||||
}
|
||||
}
|
||||
|
||||
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
|
||||
// 输入格式: /v1beta/models/gemini-2.0-flash:generateContent
|
||||
// 输出: gemini-2.0-flash
|
||||
func extractModelNameFromGeminiPath(path string) string {
|
||||
// 查找 "/models/" 的位置
|
||||
modelsPrefix := "/models/"
|
||||
modelsIndex := strings.Index(path, modelsPrefix)
|
||||
if modelsIndex == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 从 "/models/" 之后开始提取
|
||||
startIndex := modelsIndex + len(modelsPrefix)
|
||||
if startIndex >= len(path) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 查找 ":" 的位置,模型名在 ":" 之前
|
||||
colonIndex := strings.Index(path[startIndex:], ":")
|
||||
if colonIndex == -1 {
|
||||
// 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分
|
||||
return path[startIndex:]
|
||||
}
|
||||
|
||||
// 返回模型名部分
|
||||
return path[startIndex : startIndex+colonIndex]
|
||||
}
|
||||
|
||||
41
middleware/stats.go
Normal file
41
middleware/stats.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HTTPStats 存储HTTP统计信息
|
||||
type HTTPStats struct {
|
||||
activeConnections int64
|
||||
}
|
||||
|
||||
var globalStats = &HTTPStats{}
|
||||
|
||||
// StatsMiddleware 统计中间件
|
||||
func StatsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 增加活跃连接数
|
||||
atomic.AddInt64(&globalStats.activeConnections, 1)
|
||||
|
||||
// 确保在请求结束时减少连接数
|
||||
defer func() {
|
||||
atomic.AddInt64(&globalStats.activeConnections, -1)
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// StatsInfo 统计信息结构
|
||||
type StatsInfo struct {
|
||||
ActiveConnections int64 `json:"active_connections"`
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func GetStats() StatsInfo {
|
||||
return StatsInfo{
|
||||
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Ability struct {
|
||||
@@ -23,7 +24,7 @@ type Ability struct {
|
||||
func GetGroupModels(group string) []string {
|
||||
var models []string
|
||||
// Find distinct models
|
||||
DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -41,15 +42,11 @@ func GetAllEnableAbilities() []Ability {
|
||||
}
|
||||
|
||||
func getPriority(group string, model string, retry int) (int, error) {
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var priorities []int
|
||||
err := DB.Model(&Ability{}).
|
||||
Select("DISTINCT(priority)").
|
||||
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||
|
||||
@@ -75,18 +72,14 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
trueVal = "true"
|
||||
}
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
} else {
|
||||
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,9 +126,15 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
func (channel *Channel) AddAbilities() error {
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilitySet := make(map[string]struct{})
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
for _, group := range groups_ {
|
||||
key := group + "|" + model
|
||||
if _, exists := abilitySet[key]; exists {
|
||||
continue
|
||||
}
|
||||
abilitySet[key] = struct{}{}
|
||||
ability := Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
@@ -152,7 +151,7 @@ func (channel *Channel) AddAbilities() error {
|
||||
return nil
|
||||
}
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err := DB.Create(&chunk).Error
|
||||
err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -194,9 +193,15 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
|
||||
// Then add new abilities
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilitySet := make(map[string]struct{})
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
for _, group := range groups_ {
|
||||
key := group + "|" + model
|
||||
if _, exists := abilitySet[key]; exists {
|
||||
continue
|
||||
}
|
||||
abilitySet[key] = struct{}{}
|
||||
ability := Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
@@ -212,7 +217,7 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
|
||||
|
||||
if len(abilities) > 0 {
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err = tx.Create(&chunk).Error
|
||||
err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
||||
if err != nil {
|
||||
if isNewTx {
|
||||
tx.Rollback()
|
||||
@@ -261,12 +266,28 @@ func FixAbility() (int, error) {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
}
|
||||
// Delete abilities of channels that are not in channel table
|
||||
err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
|
||||
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
|
||||
if len(channelIds) > 0 {
|
||||
// Process deletion in chunks to avoid "too many placeholders" error
|
||||
for _, chunk := range lo.Chunk(channelIds, 100) {
|
||||
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If no channels exist, delete all abilities
|
||||
err = DB.Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
|
||||
return 0, err
|
||||
}
|
||||
common.SysLog("Delete all abilities successfully")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
|
||||
count += len(channelIds)
|
||||
|
||||
@@ -275,17 +296,26 @@ func FixAbility() (int, error) {
|
||||
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
return count, err
|
||||
}
|
||||
|
||||
var channels []Channel
|
||||
if len(abilityChannelIds) == 0 {
|
||||
err = DB.Find(&channels).Error
|
||||
} else {
|
||||
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
// Process query in chunks to avoid "too many placeholders" error
|
||||
err = nil
|
||||
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
|
||||
var channelsChunk []Channel
|
||||
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
|
||||
return count, err
|
||||
}
|
||||
channels = append(channels, channelsChunk...)
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
err := channel.UpdateAbilities(nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,10 +5,13 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
@@ -16,6 +19,9 @@ var channelsIDM map[int]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
newChannelId2channel := make(map[int]*Channel)
|
||||
var channels []*Channel
|
||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
||||
@@ -72,7 +78,43 @@ func SyncChannelCache(frequency int) {
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
|
||||
var channel *Channel
|
||||
var err error
|
||||
selectGroup := group
|
||||
if group == "auto" {
|
||||
if len(setting.AutoGroups) == 0 {
|
||||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||||
}
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
if common.DebugEnabled {
|
||||
println("autoGroup:", autoGroup)
|
||||
}
|
||||
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
|
||||
if channel == nil {
|
||||
continue
|
||||
} else {
|
||||
c.Set("auto_group", autoGroup)
|
||||
selectGroup = autoGroup
|
||||
if common.DebugEnabled {
|
||||
println("selectGroup:", selectGroup)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
channel, err = getRandomSatisfiedChannel(group, model, retry)
|
||||
if err != nil {
|
||||
return nil, group, err
|
||||
}
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, group, errors.New("channel not found")
|
||||
}
|
||||
return channel, selectGroup, nil
|
||||
}
|
||||
|
||||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
@@ -84,11 +126,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model, retry)
|
||||
}
|
||||
|
||||
|
||||
channelSyncLock.RLock()
|
||||
channels := group2model2channels[group][model]
|
||||
channelSyncLock.RUnlock()
|
||||
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
|
||||
@@ -46,6 +46,17 @@ func (channel *Channel) GetModels() []string {
|
||||
return strings.Split(strings.Trim(channel.Models, ","), ",")
|
||||
}
|
||||
|
||||
func (channel *Channel) GetGroups() []string {
|
||||
if channel.Group == "" {
|
||||
return []string{}
|
||||
}
|
||||
groups := strings.Split(strings.Trim(channel.Group, ","), ",")
|
||||
for i, group := range groups {
|
||||
groups[i] = strings.TrimSpace(group)
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
otherInfo := make(map[string]interface{})
|
||||
if channel.OtherInfo != "" {
|
||||
@@ -134,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
}
|
||||
|
||||
// 构造基础查询
|
||||
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
|
||||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||
|
||||
// 构造WHERE子句
|
||||
var whereClause string
|
||||
@@ -142,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
if group != "" && group != "null" {
|
||||
var groupCondition string
|
||||
if common.UsingMySQL {
|
||||
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
|
||||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||||
} else {
|
||||
// sqlite, PostgreSQL
|
||||
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
|
||||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||||
}
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||||
} else {
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||||
}
|
||||
|
||||
@@ -467,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
}
|
||||
|
||||
// 构造基础查询
|
||||
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
|
||||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||
|
||||
// 构造WHERE子句
|
||||
var whereClause string
|
||||
@@ -475,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
if group != "" && group != "null" {
|
||||
var groupCondition string
|
||||
if common.UsingMySQL {
|
||||
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
|
||||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||||
} else {
|
||||
// sqlite, PostgreSQL
|
||||
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
|
||||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||||
}
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||||
} else {
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||||
}
|
||||
|
||||
@@ -572,3 +583,53 @@ func BatchSetChannelTag(ids []int, tag *string) error {
|
||||
// 提交事务
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// CountAllChannels returns total channels in DB
|
||||
func CountAllChannels() (int64, error) {
|
||||
var total int64
|
||||
err := DB.Model(&Channel{}).Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// CountAllTags returns number of non-empty distinct tags
|
||||
func CountAllTags() (int64, error) {
|
||||
var total int64
|
||||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// Get channels of specified type with pagination
|
||||
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// Count channels of specific type
|
||||
func CountChannelsByType(channelType int) (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Return map[type]count for all channels
|
||||
func CountChannelsGroupByType() (map[int64]int64, error) {
|
||||
type result struct {
|
||||
Type int64 `gorm:"column:type"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
}
|
||||
var results []result
|
||||
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
counts[r.Type] = r.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
55
model/log.go
55
model/log.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -32,6 +33,7 @@ type Log struct {
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
@@ -61,7 +63,7 @@ func formatUserLogs(logs []*Log) {
|
||||
func GetLogByKey(key string) (logs []*Log, err error) {
|
||||
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||
var tk Token
|
||||
if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
||||
if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
|
||||
@@ -95,6 +97,15 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
|
||||
if vb, ok := v.(bool); ok && vb {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: username,
|
||||
@@ -111,7 +122,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Other: otherStr,
|
||||
Ip: func() string {
|
||||
if needRecordIp {
|
||||
return c.ClientIP()
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
@@ -128,6 +145,15 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
||||
}
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
|
||||
if vb, ok := v.(bool); ok && vb {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: username,
|
||||
@@ -144,7 +170,13 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Other: otherStr,
|
||||
Ip: func() string {
|
||||
if needRecordIp {
|
||||
return c.ClientIP()
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
@@ -184,7 +216,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
tx = tx.Where("logs.channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where("logs."+groupCol+" = ?", group)
|
||||
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
@@ -195,13 +227,18 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
channelIds := make([]int, 0)
|
||||
channelIdsMap := make(map[int]struct{})
|
||||
channelMap := make(map[int]string)
|
||||
for _, log := range logs {
|
||||
if log.ChannelId != 0 {
|
||||
channelIds = append(channelIds, log.ChannelId)
|
||||
channelIdsMap[log.ChannelId] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
channelIds := make([]int, 0, len(channelIdsMap))
|
||||
for channelId := range channelIdsMap {
|
||||
channelIds = append(channelIds, channelId)
|
||||
}
|
||||
if len(channelIds) > 0 {
|
||||
var channels []struct {
|
||||
Id int `gorm:"column:id"`
|
||||
@@ -242,7 +279,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
tx = tx.Where("logs.created_at <= ?", endTimestamp)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where("logs."+groupCol+" = ?", group)
|
||||
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
@@ -303,8 +340,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
|
||||
tx = tx.Where(logGroupCol+" = ?", group)
|
||||
rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
|
||||
}
|
||||
|
||||
tx = tx.Where("type = ?", LogTypeConsume)
|
||||
|
||||
172
model/main.go
172
model/main.go
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -15,18 +16,48 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var groupCol string
|
||||
var keyCol string
|
||||
var commonGroupCol string
|
||||
var commonKeyCol string
|
||||
var commonTrueVal string
|
||||
var commonFalseVal string
|
||||
|
||||
var logKeyCol string
|
||||
var logGroupCol string
|
||||
|
||||
func initCol() {
|
||||
// init common column names
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
keyCol = `"key"`
|
||||
|
||||
commonGroupCol = `"group"`
|
||||
commonKeyCol = `"key"`
|
||||
commonTrueVal = "true"
|
||||
commonFalseVal = "false"
|
||||
} else {
|
||||
groupCol = "`group`"
|
||||
keyCol = "`key`"
|
||||
commonGroupCol = "`group`"
|
||||
commonKeyCol = "`key`"
|
||||
commonTrueVal = "1"
|
||||
commonFalseVal = "0"
|
||||
}
|
||||
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||
switch common.LogSqlType {
|
||||
case common.DatabaseTypePostgreSQL:
|
||||
logGroupCol = `"group"`
|
||||
logKeyCol = `"key"`
|
||||
default:
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
} else {
|
||||
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
|
||||
if common.UsingPostgreSQL {
|
||||
logGroupCol = `"group"`
|
||||
logKeyCol = `"key"`
|
||||
} else {
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
}
|
||||
// log sql type and database type
|
||||
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
}
|
||||
|
||||
var DB *gorm.DB
|
||||
@@ -83,7 +114,7 @@ func CheckSetup() {
|
||||
}
|
||||
}
|
||||
|
||||
func chooseDB(envName string) (*gorm.DB, error) {
|
||||
func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
||||
defer func() {
|
||||
initCol()
|
||||
}()
|
||||
@@ -92,7 +123,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
common.UsingPostgreSQL = true
|
||||
if !isLog {
|
||||
common.UsingPostgreSQL = true
|
||||
} else {
|
||||
common.LogSqlType = common.DatabaseTypePostgreSQL
|
||||
}
|
||||
return gorm.Open(postgres.New(postgres.Config{
|
||||
DSN: dsn,
|
||||
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
||||
@@ -102,7 +137,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
||||
}
|
||||
if strings.HasPrefix(dsn, "local") {
|
||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
common.UsingSQLite = true
|
||||
if !isLog {
|
||||
common.UsingSQLite = true
|
||||
} else {
|
||||
common.LogSqlType = common.DatabaseTypeSQLite
|
||||
}
|
||||
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
@@ -117,7 +156,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
||||
dsn += "?parseTime=true"
|
||||
}
|
||||
}
|
||||
common.UsingMySQL = true
|
||||
if !isLog {
|
||||
common.UsingMySQL = true
|
||||
} else {
|
||||
common.LogSqlType = common.DatabaseTypeMySQL
|
||||
}
|
||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
@@ -131,7 +174,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
||||
}
|
||||
|
||||
func InitDB() (err error) {
|
||||
db, err := chooseDB("SQL_DSN")
|
||||
db, err := chooseDB("SQL_DSN", false)
|
||||
if err == nil {
|
||||
if common.DebugEnabled {
|
||||
db = db.Debug()
|
||||
@@ -149,7 +192,7 @@ func InitDB() (err error) {
|
||||
return nil
|
||||
}
|
||||
if common.UsingMySQL {
|
||||
_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
||||
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
||||
}
|
||||
common.SysLog("database migration started")
|
||||
err = migrateDB()
|
||||
@@ -165,7 +208,7 @@ func InitLogDB() (err error) {
|
||||
LOG_DB = DB
|
||||
return
|
||||
}
|
||||
db, err := chooseDB("LOG_SQL_DSN")
|
||||
db, err := chooseDB("LOG_SQL_DSN", true)
|
||||
if err == nil {
|
||||
if common.DebugEnabled {
|
||||
db = db.Debug()
|
||||
@@ -198,54 +241,73 @@ func InitLogDB() (err error) {
|
||||
}
|
||||
|
||||
func migrateDB() error {
|
||||
err := DB.AutoMigrate(&Channel{})
|
||||
if !common.UsingPostgreSQL {
|
||||
return migrateDBFast()
|
||||
}
|
||||
err := DB.AutoMigrate(
|
||||
&Channel{},
|
||||
&Token{},
|
||||
&User{},
|
||||
&Option{},
|
||||
&Redemption{},
|
||||
&Ability{},
|
||||
&Log{},
|
||||
&Midjourney{},
|
||||
&TopUp{},
|
||||
&QuotaData{},
|
||||
&Task{},
|
||||
&Setup{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Token{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateDBFast() error {
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 12) // Buffer size matches number of migrations
|
||||
|
||||
migrations := []struct {
|
||||
model interface{}
|
||||
name string
|
||||
}{
|
||||
{&Channel{}, "Channel"},
|
||||
{&Token{}, "Token"},
|
||||
{&User{}, "User"},
|
||||
{&Option{}, "Option"},
|
||||
{&Redemption{}, "Redemption"},
|
||||
{&Ability{}, "Ability"},
|
||||
{&Log{}, "Log"},
|
||||
{&Midjourney{}, "Midjourney"},
|
||||
{&TopUp{}, "TopUp"},
|
||||
{&QuotaData{}, "QuotaData"},
|
||||
{&Task{}, "Task"},
|
||||
{&Setup{}, "Setup"},
|
||||
}
|
||||
err = DB.AutoMigrate(&User{})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
for _, m := range migrations {
|
||||
wg.Add(1)
|
||||
go func(model interface{}, name string) {
|
||||
defer wg.Done()
|
||||
if err := DB.AutoMigrate(model); err != nil {
|
||||
errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
|
||||
}
|
||||
}(m.model, m.name)
|
||||
}
|
||||
err = DB.AutoMigrate(&Option{})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
// Wait for all migrations to complete
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// Check for any errors
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = DB.AutoMigrate(&Redemption{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Ability{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Log{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Midjourney{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&TopUp{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&QuotaData{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Task{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Setup{})
|
||||
common.SysLog("database migrated")
|
||||
//err = createRootAccountIfNeed()
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateLOGDB() error {
|
||||
|
||||
@@ -166,3 +166,40 @@ func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
|
||||
Where("id in (?)", taskIDs).
|
||||
Updates(params).Error
|
||||
}
|
||||
|
||||
// CountAllTasks returns total midjourney tasks for admin query
|
||||
func CountAllTasks(queryParams TaskQueryParams) int64 {
|
||||
var total int64
|
||||
query := DB.Model(&Midjourney{})
|
||||
if queryParams.ChannelID != "" {
|
||||
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
||||
}
|
||||
if queryParams.MjID != "" {
|
||||
query = query.Where("mj_id = ?", queryParams.MjID)
|
||||
}
|
||||
if queryParams.StartTimestamp != "" {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != "" {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
_ = query.Count(&total).Error
|
||||
return total
|
||||
}
|
||||
|
||||
// CountAllUserTask returns total midjourney tasks for user
|
||||
func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 {
|
||||
var total int64
|
||||
query := DB.Model(&Midjourney{}).Where("user_id = ?", userId)
|
||||
if queryParams.MjID != "" {
|
||||
query = query.Where("mj_id = ?", queryParams.MjID)
|
||||
}
|
||||
if queryParams.StartTimestamp != "" {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != "" {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
_ = query.Count(&total).Error
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/setting"
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -76,6 +77,9 @@ func InitOptionMap() {
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -94,12 +98,13 @@ func InitOptionMap() {
|
||||
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
||||
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
||||
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
|
||||
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
|
||||
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
//common.OptionMap["ChatLink"] = common.ChatLink
|
||||
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||
@@ -122,6 +127,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
||||
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
|
||||
|
||||
// 自动添加所有注册的模型配置
|
||||
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
||||
@@ -191,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.ImageDownloadPermission = intValue
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
|
||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
|
||||
boolValue := value == "true"
|
||||
switch key {
|
||||
case "PasswordRegisterEnabled":
|
||||
@@ -260,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
case "WorkerAllowHttpImageRequestEnabled":
|
||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
case "DefaultUseAutoGroup":
|
||||
setting.DefaultUseAutoGroup = boolValue
|
||||
case "ExposeRatioEnabled":
|
||||
ratio_setting.SetExposeRatioEnabled(boolValue)
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
@@ -286,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.PayAddress = value
|
||||
case "Chats":
|
||||
err = setting.UpdateChatsByJsonString(value)
|
||||
case "AutoGroups":
|
||||
err = setting.UpdateAutoGroupsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
setting.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
@@ -351,17 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "DataExportDefaultTime":
|
||||
common.DataExportDefaultTime = value
|
||||
case "ModelRatio":
|
||||
err = operation_setting.UpdateModelRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = setting.UpdateGroupRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateGroupRatioByJSONString(value)
|
||||
case "GroupGroupRatio":
|
||||
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
|
||||
case "UserUsableGroups":
|
||||
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = operation_setting.UpdateCompletionRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
err = operation_setting.UpdateModelPriceByJSONString(value)
|
||||
err = ratio_setting.UpdateModelPriceByJSONString(value)
|
||||
case "CacheRatio":
|
||||
err = operation_setting.UpdateCacheRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCacheRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
//case "ChatLink":
|
||||
@@ -378,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = setting.UpdatePayMethodsByJsonString(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -65,14 +65,14 @@ func updatePricing() {
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
}
|
||||
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
|
||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
pricing.ModelPrice = modelPrice
|
||||
pricing.QuotaType = 1
|
||||
} else {
|
||||
modelRatio, _ := operation_setting.GetModelRatio(model)
|
||||
modelRatio, _ := ratio_setting.GetModelRatio(model)
|
||||
pricing.ModelRatio = modelRatio
|
||||
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
|
||||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||||
pricing.QuotaType = 0
|
||||
}
|
||||
pricingMap = append(pricingMap, pricing)
|
||||
|
||||
@@ -21,6 +21,7 @@ type Redemption struct {
|
||||
Count int `json:"count" gorm:"-:all"` // only for api request
|
||||
UsedUserId int `json:"used_user_id"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint"` // 过期时间,0 表示不过期
|
||||
}
|
||||
|
||||
func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
|
||||
@@ -131,6 +132,9 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
if redemption.Status != common.RedemptionCodeStatusEnabled {
|
||||
return errors.New("该兑换码已被使用")
|
||||
}
|
||||
if redemption.ExpiredTime != 0 && redemption.ExpiredTime < common.GetTimestamp() {
|
||||
return errors.New("该兑换码已过期")
|
||||
}
|
||||
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -162,7 +166,7 @@ func (redemption *Redemption) SelectUpdate() error {
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (redemption *Redemption) Update() error {
|
||||
var err error
|
||||
err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
|
||||
err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time", "expired_time").Updates(redemption).Error
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -183,3 +187,9 @@ func DeleteRedemptionById(id int) (err error) {
|
||||
}
|
||||
return redemption.Delete()
|
||||
}
|
||||
|
||||
func DeleteInvalidRedemptions() (int64, error) {
|
||||
now := common.GetTimestamp()
|
||||
result := DB.Where("status IN ? OR (status = ? AND expired_time != 0 AND expired_time < ?)", []int{common.RedemptionCodeStatusUsed, common.RedemptionCodeStatusDisabled}, common.RedemptionCodeStatusEnabled, now).Delete(&Redemption{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@@ -302,3 +302,64 @@ func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, e
|
||||
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
|
||||
return stat, err
|
||||
}
|
||||
|
||||
// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
|
||||
func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
|
||||
var total int64
|
||||
query := DB.Model(&Task{})
|
||||
if queryParams.ChannelID != "" {
|
||||
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
||||
}
|
||||
if queryParams.Platform != "" {
|
||||
query = query.Where("platform = ?", queryParams.Platform)
|
||||
}
|
||||
if queryParams.UserID != "" {
|
||||
query = query.Where("user_id = ?", queryParams.UserID)
|
||||
}
|
||||
if len(queryParams.UserIDs) != 0 {
|
||||
query = query.Where("user_id in (?)", queryParams.UserIDs)
|
||||
}
|
||||
if queryParams.TaskID != "" {
|
||||
query = query.Where("task_id = ?", queryParams.TaskID)
|
||||
}
|
||||
if queryParams.Action != "" {
|
||||
query = query.Where("action = ?", queryParams.Action)
|
||||
}
|
||||
if queryParams.Status != "" {
|
||||
query = query.Where("status = ?", queryParams.Status)
|
||||
}
|
||||
if queryParams.StartTimestamp != 0 {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
_ = query.Count(&total).Error
|
||||
return total
|
||||
}
|
||||
|
||||
// TaskCountAllUserTask returns total tasks for given user
|
||||
func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
|
||||
var total int64
|
||||
query := DB.Model(&Task{}).Where("user_id = ?", userId)
|
||||
if queryParams.TaskID != "" {
|
||||
query = query.Where("task_id = ?", queryParams.TaskID)
|
||||
}
|
||||
if queryParams.Action != "" {
|
||||
query = query.Where("action = ?", queryParams.Action)
|
||||
}
|
||||
if queryParams.Status != "" {
|
||||
query = query.Where("status = ?", queryParams.Status)
|
||||
}
|
||||
if queryParams.Platform != "" {
|
||||
query = query.Where("platform = ?", queryParams.Platform)
|
||||
}
|
||||
if queryParams.StartTimestamp != 0 {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
_ = query.Count(&total).Error
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
|
||||
if token != "" {
|
||||
token = strings.Trim(token, "sk-")
|
||||
}
|
||||
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
|
||||
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
|
||||
return token, err
|
||||
}
|
||||
|
||||
@@ -320,3 +320,10 @@ func decreaseTokenQuota(id int, quota int) (err error) {
|
||||
).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// CountUserTokens returns total number of tokens for the given user, used for pagination
|
||||
func CountUserTokens(userId int) (int64, error) {
|
||||
var total int64
|
||||
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
func cacheSetToken(token Token) error {
|
||||
key := common.GenerateHMAC(token.Key)
|
||||
token.Clean()
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -19,7 +19,7 @@ func cacheSetToken(token Token) error {
|
||||
|
||||
func cacheDeleteToken(key string) error {
|
||||
key = common.GenerateHMAC(key)
|
||||
err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
|
||||
err := common.RedisDelKey(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ type User struct {
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
|
||||
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
||||
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||
}
|
||||
|
||||
func (user *User) ToBaseUser() *UserBase {
|
||||
@@ -175,7 +176,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
|
||||
// 如果是数字,同时搜索ID和其他字段
|
||||
likeCondition = "id = ? OR " + likeCondition
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
|
||||
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition,
|
||||
@@ -184,7 +185,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
|
||||
} else {
|
||||
// 非数字关键字,只搜索字符串字段
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
|
||||
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition,
|
||||
@@ -366,6 +367,7 @@ func (user *User) Edit(updatePassword bool) error {
|
||||
"display_name": newUser.DisplayName,
|
||||
"group": newUser.Group,
|
||||
"quota": newUser.Quota,
|
||||
"remark": newUser.Remark,
|
||||
}
|
||||
if updatePassword {
|
||||
updates["password"] = newUser.Password
|
||||
@@ -615,7 +617,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -3,11 +3,12 @@ package model
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
)
|
||||
|
||||
@@ -57,7 +58,7 @@ func invalidateUserCache(userId int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
return common.RedisHDelObj(getUserCacheKey(userId))
|
||||
return common.RedisDelKey(getUserCacheKey(userId))
|
||||
}
|
||||
|
||||
// updateUserCache updates all user cache fields using hash
|
||||
@@ -69,7 +70,7 @@ func updateUserCache(user User) error {
|
||||
return common.RedisHSetObj(
|
||||
getUserCacheKey(user.Id),
|
||||
user.ToBaseUser(),
|
||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||
time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
|
||||
}
|
||||
|
||||
func batchUpdate() {
|
||||
// check if there's any data to update
|
||||
hasData := false
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
if len(batchUpdateStores[i]) > 0 {
|
||||
hasData = true
|
||||
batchUpdateLocks[i].Unlock()
|
||||
break
|
||||
}
|
||||
batchUpdateLocks[i].Unlock()
|
||||
}
|
||||
|
||||
if !hasData {
|
||||
return
|
||||
}
|
||||
|
||||
common.SysLog("batch update started")
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
|
||||
@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
promptTokens := 0
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
||||
}
|
||||
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
preConsumedTokens = promptTokens
|
||||
relayInfo.PromptTokens = promptTokens
|
||||
}
|
||||
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
}()
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
audioRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
@@ -44,4 +44,6 @@ type TaskAdaptor interface {
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
|
||||
ParseResultUrl(resp map[string]any) (string, error)
|
||||
}
|
||||
|
||||
@@ -31,6 +31,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
|
||||
case constant.RelayModeCompletions:
|
||||
@@ -57,6 +59,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
|
||||
if !info.IsStream {
|
||||
request.EnableThinking = false
|
||||
}
|
||||
|
||||
switch info.RelayMode {
|
||||
default:
|
||||
aliReq := requestOpenAI2Ali(*request)
|
||||
@@ -70,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return ConvertRerankRequest(request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
@@ -97,6 +105,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = aliEmbeddingHandler(c, resp)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
|
||||
@@ -8,6 +8,7 @@ var ModelList = []string{
|
||||
"qwq-32b",
|
||||
"qwen3-235b-a22b",
|
||||
"text-embedding-v1",
|
||||
"gte-rerank-v2",
|
||||
}
|
||||
|
||||
var ChannelName = "ali"
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package ali
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type AliMessage struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
@@ -97,3 +99,28 @@ type AliImageRequest struct {
|
||||
} `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type AliRerankParameters struct {
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
}
|
||||
|
||||
type AliRerankInput struct {
|
||||
Query string `json:"query"`
|
||||
Documents []any `json:"documents"`
|
||||
}
|
||||
|
||||
type AliRerankRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliRerankInput `json:"input"`
|
||||
Parameters AliRerankParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliRerankResponse struct {
|
||||
Output struct {
|
||||
Results []dto.RerankResponseResult `json:"results"`
|
||||
} `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
RequestId string `json:"request_id"`
|
||||
AliError
|
||||
}
|
||||
|
||||
83
relay/channel/ali/rerank.go
Normal file
83
relay/channel/ali/rerank.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
|
||||
returnDocuments := request.ReturnDocuments
|
||||
if returnDocuments == nil {
|
||||
t := true
|
||||
returnDocuments = &t
|
||||
}
|
||||
return &AliRerankRequest{
|
||||
Model: request.Model,
|
||||
Input: AliRerankInput{
|
||||
Query: request.Query,
|
||||
Documents: request.Documents,
|
||||
},
|
||||
Parameters: AliRerankParameters{
|
||||
TopN: &request.TopN,
|
||||
ReturnDocuments: returnDocuments,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
var aliResponse AliRerankResponse
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
usage := dto.Usage{
|
||||
PromptTokens: aliResponse.Usage.TotalTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: aliResponse.Usage.TotalTokens,
|
||||
}
|
||||
rerankResponse := dto.RerankResponse{
|
||||
Results: aliResponse.Output.Results,
|
||||
Usage: usage,
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
return nil, &usage
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package ali
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -11,6 +10,8 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
@@ -27,9 +28,6 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
|
||||
if request.Model == "" {
|
||||
request.Model = "text-embedding-v1"
|
||||
}
|
||||
return &AliEmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Input: struct {
|
||||
@@ -64,7 +62,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||
model := c.GetString("model")
|
||||
if model == "" {
|
||||
model = "text-embedding-v4"
|
||||
}
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -75,11 +77,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
|
||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Model: model,
|
||||
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
@@ -94,12 +96,11 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
|
||||
}
|
||||
|
||||
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
|
||||
content, _ := json.Marshal(response.Output.Text)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Content: response.Output.Text,
|
||||
},
|
||||
FinishReason: response.Output.FinishReason,
|
||||
}
|
||||
|
||||
@@ -104,6 +104,105 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
return targetConn, nil
|
||||
}
|
||||
|
||||
func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
|
||||
pingerCtx, stopPinger := context.WithCancel(context.Background())
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
// 增加panic恢复处理
|
||||
if r := recover(); r != nil {
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
|
||||
}
|
||||
}
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine stopped.")
|
||||
}
|
||||
}()
|
||||
|
||||
if pingInterval <= 0 {
|
||||
pingInterval = helper.DefaultPingInterval
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(pingInterval)
|
||||
// 确保在任何情况下都清理ticker
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping ticker stopped")
|
||||
}
|
||||
}()
|
||||
|
||||
var pingMutex sync.Mutex
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine started")
|
||||
}
|
||||
|
||||
// 增加超时控制,防止goroutine长时间运行
|
||||
maxPingDuration := 120 * time.Minute // 最大ping持续时间
|
||||
pingTimeout := time.NewTimer(maxPingDuration)
|
||||
defer pingTimeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
// 发送 ping 数据
|
||||
case <-ticker.C:
|
||||
if err := sendPingData(c, &pingMutex); err != nil {
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping error, stopping goroutine:", err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
// 收到退出信号
|
||||
case <-pingerCtx.Done():
|
||||
return
|
||||
// request 结束
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
// 超时保护,防止goroutine无限运行
|
||||
case <-pingTimeout.C:
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine timeout, stopping")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return stopPinger
|
||||
}
|
||||
|
||||
func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
|
||||
// 增加超时控制,防止锁死等待
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
err := helper.PingData(c)
|
||||
if err != nil {
|
||||
common2.LogError(c, "SSE ping error: "+err.Error())
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping data sent.")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
// 设置发送ping数据的超时时间
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-time.After(10 * time.Second):
|
||||
return errors.New("SSE ping data send timeout")
|
||||
case <-c.Request.Context().Done():
|
||||
return errors.New("request context cancelled during ping")
|
||||
}
|
||||
}
|
||||
|
||||
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
|
||||
var client *http.Client
|
||||
var err error
|
||||
@@ -115,68 +214,36 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
} else {
|
||||
client = service.GetHttpClient()
|
||||
}
|
||||
// 流式请求 ping 保活
|
||||
var stopPinger func()
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
pingEnabled := generalSettings.PingIntervalEnabled
|
||||
var pingerWg sync.WaitGroup
|
||||
|
||||
var stopPinger context.CancelFunc
|
||||
if info.IsStream {
|
||||
helper.SetEventStreamHeaders(c)
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
var pingerCtx context.Context
|
||||
pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
|
||||
|
||||
if pingEnabled {
|
||||
pingerWg.Add(1)
|
||||
gopool.Go(func() {
|
||||
defer pingerWg.Done()
|
||||
if pingInterval <= 0 {
|
||||
pingInterval = helper.DefaultPingInterval
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(pingInterval)
|
||||
defer ticker.Stop()
|
||||
var pingMutex sync.Mutex
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine started")
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
pingMutex.Lock()
|
||||
err2 := helper.PingData(c)
|
||||
pingMutex.Unlock()
|
||||
if err2 != nil {
|
||||
common2.LogError(c, "SSE ping error: "+err.Error())
|
||||
return
|
||||
}
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping data sent.")
|
||||
}
|
||||
case <-pingerCtx.Done():
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine stopped.")
|
||||
}
|
||||
return
|
||||
// 处理流式请求的 ping 保活
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
if generalSettings.PingIntervalEnabled {
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
stopPinger = startPingKeepAlive(c, pingInterval)
|
||||
// 使用defer确保在任何情况下都能停止ping goroutine
|
||||
defer func() {
|
||||
if stopPinger != nil {
|
||||
stopPinger()
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine stopped by defer")
|
||||
}
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
// request结束后停止ping
|
||||
if info.IsStream && pingEnabled {
|
||||
stopPinger()
|
||||
pingerWg.Wait()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, errors.New("resp is nil")
|
||||
}
|
||||
|
||||
_ = req.Body.Close()
|
||||
_ = c.Request.Body.Close()
|
||||
return resp, nil
|
||||
|
||||
@@ -11,6 +11,8 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -41,6 +43,16 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
},
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"us": true,
|
||||
"ap": true,
|
||||
"eu": true,
|
||||
},
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0": {
|
||||
"us": true,
|
||||
"ap": true,
|
||||
"eu": true,
|
||||
},
|
||||
"anthropic.claude-opus-4-20250514-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -53,12 +53,11 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
}
|
||||
|
||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
|
||||
content, _ := json.Marshal(response.Result)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Content: response.Result,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -49,6 +50,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if strings.HasSuffix(info.UpstreamModelName, "-search") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
|
||||
request.Model = info.UpstreamModelName
|
||||
toMap := request.ToMap()
|
||||
toMap["web_search"] = map[string]any{
|
||||
"enable": true,
|
||||
"enable_citation": true,
|
||||
"enable_trace": true,
|
||||
"enable_status": false,
|
||||
}
|
||||
return toMap, nil
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -38,10 +38,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||
a.RequestMode = RequestModeMessage
|
||||
} else {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
|
||||
a.RequestMode = RequestModeCompletion
|
||||
} else {
|
||||
a.RequestMode = RequestModeMessage
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,10 @@ var ModelList = []string{
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-7-sonnet-20250219-thinking",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-sonnet-4-20250514-thinking",
|
||||
"claude-opus-4-20250514",
|
||||
"claude-opus-4-20250514-thinking",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -48,9 +48,9 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
|
||||
prompt := ""
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "user" {
|
||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
|
||||
} else if message.Role == "system" {
|
||||
if prompt == "" {
|
||||
prompt = message.StringContent()
|
||||
@@ -113,7 +113,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
||||
BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
@@ -155,15 +155,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
}
|
||||
if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
|
||||
if lastMessage.IsStringContent() && message.IsStringContent() {
|
||||
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
|
||||
fmtMessage.Content = content
|
||||
fmtMessage.SetStringContent(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
|
||||
// delete last message
|
||||
formatMessages = formatMessages[:len(formatMessages)-1]
|
||||
}
|
||||
}
|
||||
if fmtMessage.Content == nil {
|
||||
content, _ := json.Marshal("...")
|
||||
fmtMessage.Content = content
|
||||
fmtMessage.SetStringContent("...")
|
||||
}
|
||||
formatMessages = append(formatMessages, fmtMessage)
|
||||
lastMessage = fmtMessage
|
||||
@@ -397,12 +395,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
|
||||
thinkingContent := ""
|
||||
|
||||
if reqMode == RequestModeCompletion {
|
||||
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
@@ -457,6 +454,7 @@ type ClaudeResponseInfo struct {
|
||||
Model string
|
||||
ResponseText strings.Builder
|
||||
Usage *dto.Usage
|
||||
Done bool
|
||||
}
|
||||
|
||||
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
@@ -464,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||
claudeInfo.Model = claudeResponse.Message.Model
|
||||
|
||||
// message_start, 获取usage
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
if claudeResponse.Delta.Thinking != "" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
|
||||
// 判断是否完整
|
||||
claudeInfo.Done = true
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
} else {
|
||||
return false
|
||||
@@ -509,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
@@ -547,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||
if common.DebugEnabled {
|
||||
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 说明流模式建立失败,可能为官方出错
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
//
|
||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -622,10 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
|
||||
}
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
|
||||
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
return nil, usage
|
||||
|
||||
@@ -3,7 +3,6 @@ package cohere
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -78,7 +77,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
||||
}
|
||||
|
||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
usage := &dto.Usage{}
|
||||
responseText := ""
|
||||
@@ -163,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
@@ -195,11 +194,10 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
openaiResp.Model = modelName
|
||||
openaiResp.Usage = usage
|
||||
|
||||
content, _ := json.Marshal(cohereResp.Text)
|
||||
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: dto.Message{Content: content, Role: "assistant"},
|
||||
Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
|
||||
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ type CozeError struct {
|
||||
type CozeEnterMessage struct {
|
||||
Role string `json:"role"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Content any `json:"content,omitempty"`
|
||||
MetaData json.RawMessage `json:"meta_data,omitempty"`
|
||||
ContentType string `json:"content_type,omitempty"`
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
|
||||
var currentEvent string
|
||||
var currentData string
|
||||
var usage dto.Usage
|
||||
var usage = &dto.Usage{}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
if line == "" {
|
||||
if currentEvent != "" && currentData != "" {
|
||||
// handle last event
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||
currentEvent = ""
|
||||
currentData = ""
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
|
||||
// Last event
|
||||
if currentEvent != "" && currentData != "" {
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
helper.Done(c)
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
}
|
||||
|
||||
return nil, &usage
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||
|
||||
@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
return true
|
||||
})
|
||||
helper.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("close_response_body_failed: " + err.Error())
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return nil, usage
|
||||
@@ -278,12 +271,11 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: difyResponse.MetaData.Usage,
|
||||
}
|
||||
content, _ := json.Marshal(difyResponse.Answer)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Content: difyResponse.Answer,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"strings"
|
||||
@@ -71,10 +72,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// suffix -thinking and -nothinking
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||
info.UpstreamModelName = parts[0]
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
}
|
||||
}
|
||||
@@ -165,6 +169,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
if info.IsStream {
|
||||
return GeminiTextGenerationStreamHandler(c, resp, info)
|
||||
} else {
|
||||
return GeminiTextGenerationHandler(c, resp, info)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return GeminiImageHandler(c, resp, info)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package gemini
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []GeminiChatTool `json:"tools,omitempty"`
|
||||
SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
|
||||
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiThinkingConfig struct {
|
||||
@@ -22,19 +24,38 @@ type GeminiInlineData struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
|
||||
func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiInlineData // Use type alias to avoid recursion
|
||||
var aux struct {
|
||||
Alias
|
||||
MimeTypeSnake string `json:"mime_type"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
|
||||
|
||||
// Prioritize snake_case if present
|
||||
if aux.MimeTypeSnake != "" {
|
||||
g.MimeType = aux.MimeTypeSnake
|
||||
} else if aux.MimeType != "" { // Fallback to camelCase from Alias
|
||||
g.MimeType = aux.MimeType
|
||||
}
|
||||
// g.Data would be populated by aux.Alias.Data
|
||||
return nil
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
FunctionName string `json:"name"`
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type GeminiFunctionResponseContent struct {
|
||||
Name string `json:"name"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response GeminiFunctionResponseContent `json:"response"`
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
|
||||
type GeminiPartExecutableCode struct {
|
||||
@@ -54,6 +75,7 @@ type GeminiFileData struct {
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
@@ -62,6 +84,33 @@ type GeminiPart struct {
|
||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
|
||||
func (p *GeminiPart) UnmarshalJSON(data []byte) error {
|
||||
// Alias to avoid recursion during unmarshalling
|
||||
type Alias GeminiPart
|
||||
var aux struct {
|
||||
Alias
|
||||
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Assign fields from alias
|
||||
*p = GeminiPart(aux.Alias)
|
||||
|
||||
// Prioritize snake_case for InlineData if present
|
||||
if aux.InlineDataSnake != nil {
|
||||
p.InlineData = aux.InlineDataSnake
|
||||
} else if aux.InlineData != nil { // Fallback to camelCase from Alias
|
||||
p.InlineData = aux.InlineData
|
||||
}
|
||||
// Other fields like Text, FunctionCall etc. are already populated via aux.Alias
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
@@ -91,6 +140,7 @@ type GeminiChatGenerationConfig struct {
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
@@ -116,10 +166,16 @@ type GeminiChatResponse struct {
|
||||
}
|
||||
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
||||
}
|
||||
|
||||
type GeminiPromptTokensDetails struct {
|
||||
Modality string `json:"modality"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
|
||||
// Imagen related structs
|
||||
|
||||
146
relay/channel/gemini/relay-gemini-native.go
Normal file
146
relay/channel/gemini/relay-gemini-native.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
|
||||
// 解析为 Gemini 原生响应格式
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = common.DecodeJson(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||
jsonResponse, err := json.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 设置响应头并写入响应
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
// 统计图片数量
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
err = helper.StringData(c, data)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
}
|
||||
}
|
||||
|
||||
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||
//helper.Done(c)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
@@ -12,12 +12,72 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var geminiSupportedMimeTypes = map[string]bool{
|
||||
"application/pdf": true,
|
||||
"audio/mpeg": true,
|
||||
"audio/mp3": true,
|
||||
"audio/wav": true,
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"text/plain": true,
|
||||
"video/mov": true,
|
||||
"video/mpeg": true,
|
||||
"video/mp4": true,
|
||||
"video/mpg": true,
|
||||
"video/avi": true,
|
||||
"video/wmv": true,
|
||||
"video/mpegps": true,
|
||||
"video/flv": true,
|
||||
}
|
||||
|
||||
// Gemini 允许的思考预算范围
|
||||
const (
|
||||
pro25MinBudget = 128
|
||||
pro25MaxBudget = 32768
|
||||
flash25MaxBudget = 24576
|
||||
flash25LiteMinBudget = 512
|
||||
flash25LiteMaxBudget = 24576
|
||||
)
|
||||
|
||||
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
|
||||
func clampThinkingBudget(modelName string, budget int) int {
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
|
||||
if is25FlashLite {
|
||||
if budget < flash25LiteMinBudget {
|
||||
return flash25LiteMinBudget
|
||||
}
|
||||
if budget > flash25LiteMaxBudget {
|
||||
return flash25LiteMaxBudget
|
||||
}
|
||||
} else if isNew25Pro {
|
||||
if budget < pro25MinBudget {
|
||||
return pro25MinBudget
|
||||
}
|
||||
if budget > pro25MaxBudget {
|
||||
return pro25MaxBudget
|
||||
}
|
||||
} else { // 其他模型
|
||||
if budget < 0 {
|
||||
return 0
|
||||
}
|
||||
if budget > flash25MaxBudget {
|
||||
return flash25MaxBudget
|
||||
}
|
||||
}
|
||||
return budget
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
|
||||
@@ -39,18 +99,54 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
if budgetTokens == 0 || budgetTokens > 24576 {
|
||||
budgetTokens = 24576
|
||||
modelName := info.UpstreamModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if strings.Contains(modelName, "-thinking-") {
|
||||
parts := strings.SplitN(modelName, "-thinking-", 2)
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(int(budgetTokens)),
|
||||
IncludeThoughts: true,
|
||||
} else if strings.HasSuffix(modelName, "-thinking") {
|
||||
unsupportedModels := []string{
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
}
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
isUnsupported := false
|
||||
for _, unsupportedModel := range unsupportedModels {
|
||||
if strings.HasPrefix(modelName, unsupportedModel) {
|
||||
isUnsupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
|
||||
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -112,12 +208,6 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
|
||||
// json_data, _ := json.Marshal(geminiRequest.Tools)
|
||||
// common.SysLog("tools_json: " + string(json_data))
|
||||
} else if textRequest.Functions != nil {
|
||||
//geminiRequest.Tools = []GeminiChatTool{
|
||||
// {
|
||||
// FunctionDeclarations: textRequest.Functions,
|
||||
// },
|
||||
//}
|
||||
}
|
||||
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
@@ -148,17 +238,27 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
|
||||
name = val
|
||||
}
|
||||
content := common.StrToMap(message.StringContent())
|
||||
var contentMap map[string]interface{}
|
||||
contentStr := message.StringContent()
|
||||
|
||||
// 1. 尝试解析为 JSON 对象
|
||||
if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
|
||||
// 2. 如果失败,尝试解析为 JSON 数组
|
||||
var contentSlice []interface{}
|
||||
if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
|
||||
// 如果是数组,包装成对象
|
||||
contentMap = map[string]interface{}{"result": contentSlice}
|
||||
} else {
|
||||
// 3. 如果再次失败,作为纯文本处理
|
||||
contentMap = map[string]interface{}{"content": contentStr}
|
||||
}
|
||||
}
|
||||
|
||||
functionResp := &FunctionResponse{
|
||||
Name: name,
|
||||
Response: GeminiFunctionResponseContent{
|
||||
Name: name,
|
||||
Content: content,
|
||||
},
|
||||
}
|
||||
if content == nil {
|
||||
functionResp.Response.Content = message.StringContent()
|
||||
Name: name,
|
||||
Response: contentMap,
|
||||
}
|
||||
|
||||
*parts = append(*parts, GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
@@ -208,14 +308,21 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
// 是url,获取文件的类型和base64编码的数据
|
||||
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
|
||||
}
|
||||
|
||||
// 校验 MimeType 是否在 Gemini 支持的白名单中
|
||||
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
|
||||
url := part.GetImageMedia().Url
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: fileData.MimeType,
|
||||
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
|
||||
Data: fileData.Base64Data,
|
||||
},
|
||||
})
|
||||
@@ -249,13 +356,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if part.GetInputAudio().Data == "" {
|
||||
return nil, fmt.Errorf("only base64 audio is supported in gemini")
|
||||
}
|
||||
format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
|
||||
base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: format,
|
||||
MimeType: "audio/" + part.GetInputAudio().Format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
@@ -268,7 +375,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
if len(content.Parts) > 0 {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(system_content) > 0 {
|
||||
@@ -284,100 +393,126 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return &geminiRequest, nil
|
||||
}
|
||||
|
||||
// Helper function to get a list of supported MIME types for error messages
|
||||
func getSupportedMimeTypesList() []string {
|
||||
keys := make([]string, 0, len(geminiSupportedMimeTypes))
|
||||
for k := range geminiSupportedMimeTypes {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
|
||||
func cleanFunctionParameters(params interface{}) interface{} {
|
||||
if params == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
paramMap, ok := params.(map[string]interface{})
|
||||
if !ok {
|
||||
// Not a map, return as is (e.g., could be an array or primitive)
|
||||
return params
|
||||
}
|
||||
switch v := params.(type) {
|
||||
case map[string]interface{}:
|
||||
// Create a copy to avoid modifying the original
|
||||
cleanedMap := make(map[string]interface{})
|
||||
for k, val := range v {
|
||||
cleanedMap[k] = val
|
||||
}
|
||||
|
||||
// Create a copy to avoid modifying the original
|
||||
cleanedMap := make(map[string]interface{})
|
||||
for k, v := range paramMap {
|
||||
cleanedMap[k] = v
|
||||
}
|
||||
// Remove unsupported root-level fields
|
||||
delete(cleanedMap, "default")
|
||||
delete(cleanedMap, "exclusiveMaximum")
|
||||
delete(cleanedMap, "exclusiveMinimum")
|
||||
delete(cleanedMap, "$schema")
|
||||
delete(cleanedMap, "additionalProperties")
|
||||
|
||||
// Remove unsupported root-level fields
|
||||
delete(cleanedMap, "default")
|
||||
delete(cleanedMap, "exclusiveMaximum")
|
||||
delete(cleanedMap, "exclusiveMinimum")
|
||||
delete(cleanedMap, "$schema")
|
||||
delete(cleanedMap, "additionalProperties")
|
||||
|
||||
// Clean properties
|
||||
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
|
||||
cleanedProps := make(map[string]interface{})
|
||||
for propName, propValue := range props {
|
||||
propMap, ok := propValue.(map[string]interface{})
|
||||
if !ok {
|
||||
cleanedProps[propName] = propValue // Keep non-map properties
|
||||
continue
|
||||
}
|
||||
|
||||
// Create a copy of the property map
|
||||
cleanedPropMap := make(map[string]interface{})
|
||||
for k, v := range propMap {
|
||||
cleanedPropMap[k] = v
|
||||
}
|
||||
|
||||
// Remove unsupported fields
|
||||
delete(cleanedPropMap, "default")
|
||||
delete(cleanedPropMap, "exclusiveMaximum")
|
||||
delete(cleanedPropMap, "exclusiveMinimum")
|
||||
delete(cleanedPropMap, "$schema")
|
||||
delete(cleanedPropMap, "additionalProperties")
|
||||
|
||||
// Check and clean 'format' for string types
|
||||
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" {
|
||||
if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists {
|
||||
if formatValue != "enum" && formatValue != "date-time" {
|
||||
delete(cleanedPropMap, "format")
|
||||
}
|
||||
// Check and clean 'format' for string types
|
||||
if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
|
||||
if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
|
||||
if formatValue != "enum" && formatValue != "date-time" {
|
||||
delete(cleanedMap, "format")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively clean nested properties within this property if it's an object/array
|
||||
// Check the type before recursing
|
||||
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") {
|
||||
cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap)
|
||||
} else {
|
||||
cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing
|
||||
// Clean properties
|
||||
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
|
||||
cleanedProps := make(map[string]interface{})
|
||||
for propName, propValue := range props {
|
||||
cleanedProps[propName] = cleanFunctionParameters(propValue)
|
||||
}
|
||||
|
||||
cleanedMap["properties"] = cleanedProps
|
||||
}
|
||||
cleanedMap["properties"] = cleanedProps
|
||||
}
|
||||
|
||||
// Recursively clean items in arrays if needed (e.g., type: array, items: { ... })
|
||||
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
|
||||
cleanedMap["items"] = cleanFunctionParameters(items)
|
||||
}
|
||||
// Also handle items if it's an array of schemas
|
||||
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
|
||||
cleanedItemsArray := make([]interface{}, len(itemsArray))
|
||||
for i, item := range itemsArray {
|
||||
cleanedItemsArray[i] = cleanFunctionParameters(item)
|
||||
// Recursively clean items in arrays
|
||||
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
|
||||
cleanedMap["items"] = cleanFunctionParameters(items)
|
||||
}
|
||||
cleanedMap["items"] = cleanedItemsArray
|
||||
}
|
||||
|
||||
// Recursively clean other schema composition keywords if necessary
|
||||
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
|
||||
if nested, ok := cleanedMap[field].([]interface{}); ok {
|
||||
cleanedNested := make([]interface{}, len(nested))
|
||||
for i, item := range nested {
|
||||
cleanedNested[i] = cleanFunctionParameters(item)
|
||||
// Also handle items if it's an array of schemas
|
||||
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
|
||||
cleanedItemsArray := make([]interface{}, len(itemsArray))
|
||||
for i, item := range itemsArray {
|
||||
cleanedItemsArray[i] = cleanFunctionParameters(item)
|
||||
}
|
||||
cleanedMap[field] = cleanedNested
|
||||
cleanedMap["items"] = cleanedItemsArray
|
||||
}
|
||||
}
|
||||
|
||||
return cleanedMap
|
||||
// Recursively clean other schema composition keywords
|
||||
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
|
||||
if nested, ok := cleanedMap[field].([]interface{}); ok {
|
||||
cleanedNested := make([]interface{}, len(nested))
|
||||
for i, item := range nested {
|
||||
cleanedNested[i] = cleanFunctionParameters(item)
|
||||
}
|
||||
cleanedMap[field] = cleanedNested
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively clean patternProperties
|
||||
if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
|
||||
cleanedPatternProps := make(map[string]interface{})
|
||||
for pattern, schema := range patternProps {
|
||||
cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
|
||||
}
|
||||
cleanedMap["patternProperties"] = cleanedPatternProps
|
||||
}
|
||||
|
||||
// Recursively clean definitions
|
||||
if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
|
||||
cleanedDefinitions := make(map[string]interface{})
|
||||
for defName, defSchema := range definitions {
|
||||
cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
|
||||
}
|
||||
cleanedMap["definitions"] = cleanedDefinitions
|
||||
}
|
||||
|
||||
// Recursively clean $defs (newer JSON Schema draft)
|
||||
if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
|
||||
cleanedDefs := make(map[string]interface{})
|
||||
for defName, defSchema := range defs {
|
||||
cleanedDefs[defName] = cleanFunctionParameters(defSchema)
|
||||
}
|
||||
cleanedMap["$defs"] = cleanedDefs
|
||||
}
|
||||
|
||||
// Clean conditional keywords
|
||||
for _, field := range []string{"if", "then", "else", "not"} {
|
||||
if nested, ok := cleanedMap[field]; ok {
|
||||
cleanedMap[field] = cleanFunctionParameters(nested)
|
||||
}
|
||||
}
|
||||
|
||||
return cleanedMap
|
||||
|
||||
case []interface{}:
|
||||
// Handle arrays of schemas
|
||||
cleanedArray := make([]interface{}, len(v))
|
||||
for i, item := range v {
|
||||
cleanedArray[i] = cleanFunctionParameters(item)
|
||||
}
|
||||
return cleanedArray
|
||||
|
||||
default:
|
||||
// Not a map or array, return as is (e.g., could be a primitive)
|
||||
return params
|
||||
}
|
||||
}
|
||||
|
||||
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
|
||||
@@ -512,21 +647,20 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
content, _ := json.Marshal("")
|
||||
isToolCall := false
|
||||
for _, candidate := range response.Candidates {
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: int(candidate.Index),
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Content: "",
|
||||
},
|
||||
FinishReason: constant.FinishReasonStop,
|
||||
}
|
||||
@@ -539,6 +673,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
if call := getResponseToolCall(&part); call != nil {
|
||||
toolCalls = append(toolCalls, *call)
|
||||
}
|
||||
} else if part.Thought {
|
||||
choice.Message.ReasoningContent = part.Text
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
|
||||
@@ -556,7 +692,6 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
choice.Message.SetToolCalls(toolCalls)
|
||||
isToolCall = true
|
||||
}
|
||||
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
|
||||
}
|
||||
@@ -596,6 +731,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
}
|
||||
var texts []string
|
||||
isTools := false
|
||||
isThought := false
|
||||
if candidate.FinishReason != nil {
|
||||
// p := GeminiConvertFinishReason(*candidate.FinishReason)
|
||||
switch *candidate.FinishReason {
|
||||
@@ -620,6 +756,9 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
call.SetIndex(len(choice.Delta.ToolCalls))
|
||||
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
|
||||
}
|
||||
} else if part.Thought {
|
||||
isThought = true
|
||||
texts = append(texts, part.Text)
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
|
||||
@@ -632,7 +771,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
}
|
||||
}
|
||||
}
|
||||
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
||||
if isThought {
|
||||
choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
|
||||
} else {
|
||||
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
||||
}
|
||||
if isTools {
|
||||
choice.FinishReason = &constant.FinishReasonToolCalls
|
||||
}
|
||||
@@ -647,7 +790,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
// responseText := ""
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
@@ -672,6 +815,13 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
}
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
@@ -716,8 +866,11 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = json.Unmarshal(responseBody, &geminiResponse)
|
||||
err = common.DecodeJson(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -732,7 +885,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
@@ -743,6 +896,14 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,13 +1,55 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var mistralToolCallIdRegexp = regexp.MustCompile("^[a-zA-Z0-9]{9}$")
|
||||
|
||||
func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
idMap := make(map[string]string)
|
||||
for _, message := range request.Messages {
|
||||
// 1. tool_calls.id
|
||||
toolCalls := message.ParseToolCalls()
|
||||
if toolCalls != nil {
|
||||
for i := range toolCalls {
|
||||
if !mistralToolCallIdRegexp.MatchString(toolCalls[i].ID) {
|
||||
if newId, ok := idMap[toolCalls[i].ID]; ok {
|
||||
toolCalls[i].ID = newId
|
||||
} else {
|
||||
newId, err := common.GenerateRandomCharsKey(9)
|
||||
if err == nil {
|
||||
idMap[toolCalls[i].ID] = newId
|
||||
toolCalls[i].ID = newId
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
message.SetToolCalls(toolCalls)
|
||||
}
|
||||
|
||||
// 2. tool_call_id
|
||||
if message.ToolCallId != "" {
|
||||
if newId, ok := idMap[message.ToolCallId]; ok {
|
||||
message.ToolCallId = newId
|
||||
} else {
|
||||
if !mistralToolCallIdRegexp.MatchString(message.ToolCallId) {
|
||||
newId, err := common.GenerateRandomCharsKey(9)
|
||||
if err == nil {
|
||||
idMap[message.ToolCallId] = newId
|
||||
message.ToolCallId = newId
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mediaMessages := message.ParseContent()
|
||||
if message.Role == "assistant" && message.ToolCalls != nil && message.Content == "" {
|
||||
mediaMessages = []dto.MediaContent{}
|
||||
}
|
||||
for j, mediaMessage := range mediaMessages {
|
||||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
|
||||
@@ -88,6 +88,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
|
||||
// 特殊处理 responses API
|
||||
if info.RelayMode == constant.RelayModeResponses {
|
||||
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
}
|
||||
|
||||
model_ := info.UpstreamModelName
|
||||
// 2025年5月10日后创建的渠道不移除.
|
||||
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
@@ -180,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
} else {
|
||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
||||
@@ -215,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
forceFormat := false
|
||||
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
||||
forceFormat = forceFmt
|
||||
@@ -224,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
@@ -273,36 +274,25 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
}
|
||||
|
||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
// the status code has been judged before, if there is a body reading failure,
|
||||
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||
// and can be terminated directly.
|
||||
defer resp.Body.Close()
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
c.Writer.WriteHeaderNow()
|
||||
_, err := io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -356,13 +346,14 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
||||
if err = c.ShouldBind(&reqBody); err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||
reqFp, err := reqBody.File.Open()
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
defer reqFp.Close()
|
||||
|
||||
tmpFp, err := os.CreateTemp("", "audio-*")
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
@@ -376,7 +367,7 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
|
||||
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
||||
tempStr := responseTextBuilder.String()
|
||||
if len(tempStr) > 0 {
|
||||
// 非正常结束,使用输出文本的 token 数量
|
||||
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
|
||||
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
|
||||
usage.CompletionTokens = completionTokens
|
||||
}
|
||||
}
|
||||
|
||||
9
relay/channel/openrouter/dto.go
Normal file
9
relay/channel/openrouter/dto.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package openrouter
|
||||
|
||||
type RequestReasoning struct {
|
||||
// One of the following (not both):
|
||||
Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
|
||||
MaxTokens int `json:"max_tokens,omitempty"` // Specific token limit (Anthropic-style)
|
||||
// Optional: Default is false. All models support this.
|
||||
Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package palm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -45,12 +44,11 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
content, _ := json.Marshal(candidate.Content)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Content: candidate.Content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
@@ -74,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
|
||||
|
||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
@@ -157,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
|
||||
312
relay/channel/task/kling/adaptor.go
Normal file
312
relay/channel/task/kling/adaptor.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package kling
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
TaskID string `json:"task_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
accessKey string
|
||||
secretKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
|
||||
// apiKey format: "access_key,secret_key"
|
||||
keyParts := strings.Split(info.ApiKey, ",")
|
||||
if len(keyParts) == 2 {
|
||||
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := "generate"
|
||||
info.Action = action
|
||||
|
||||
var req SubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("kling_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
token, err := a.createJWTToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JWT token: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Kling specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("kling_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
|
||||
body := a.convertToRequestPayload(&req)
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt Kling response parse first.
|
||||
var kResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
|
||||
return kResp.Data.TaskID, responseBody, nil
|
||||
}
|
||||
|
||||
// Fallback generic task response.
|
||||
var generic dto.TaskResponse[string]
|
||||
if err := json.Unmarshal(responseBody, &generic); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !generic.IsSuccess() {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
|
||||
return generic.Data, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, err := a.createJWTTokenWithKey(key)
|
||||
if err != nil {
|
||||
token = key
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return "kling"
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
|
||||
r := &requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: defaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
Model: req.Model,
|
||||
ModelName: req.Model,
|
||||
CfgScale: 0.5,
|
||||
}
|
||||
if r.Model == "" {
|
||||
r.Model = "kling-v1"
|
||||
r.ModelName = "kling-v1"
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) getAspectRatio(size string) string {
|
||||
switch size {
|
||||
case "1024x1024", "512x512":
|
||||
return "1:1"
|
||||
case "1280x720", "1920x1080":
|
||||
return "16:9"
|
||||
case "720x1280", "1080x1920":
|
||||
return "9:16"
|
||||
default:
|
||||
return "1:1"
|
||||
}
|
||||
}
|
||||
|
||||
func defaultString(s, def string) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return def
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func defaultInt(v int, def int) int {
|
||||
if v == 0 {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// ============================
|
||||
// JWT helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||||
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
parts := strings.Split(apiKey, ",")
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
|
||||
}
|
||||
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
|
||||
if accessKey == "" || secretKey == "" {
|
||||
return "", fmt.Errorf("access key and secret key are required")
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": accessKey,
|
||||
"exp": now + 1800, // 30 minutes
|
||||
"nbf": now - 5,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token.Header["typ"] = "JWT"
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
|
||||
// ParseResultUrl 提取视频任务结果的 url
|
||||
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("data field not found or invalid")
|
||||
}
|
||||
taskResult, ok := data["task_result"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("task_result field not found or invalid")
|
||||
}
|
||||
videos, ok := taskResult["videos"].([]interface{})
|
||||
if !ok || len(videos) == 0 {
|
||||
return "", fmt.Errorf("videos field not found or empty")
|
||||
}
|
||||
video, ok := videos[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("video item invalid")
|
||||
}
|
||||
url, ok := video["url"].(string)
|
||||
if !ok || url == "" {
|
||||
return "", fmt.Errorf("url field not found or invalid")
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||
return "", nil // todo implement this method if needed
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = tencentStreamHandler(c, resp)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = tencentHandler(c, resp)
|
||||
}
|
||||
|
||||
@@ -56,12 +56,11 @@ func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextRespon
|
||||
},
|
||||
}
|
||||
if len(response.Choices) > 0 {
|
||||
content, _ := json.Marshal(response.Choices[0].Messages.Content)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Content: response.Choices[0].Messages.Content,
|
||||
},
|
||||
FinishReason: response.Choices[0].FinishReason,
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/setting/model_setting"
|
||||
"strings"
|
||||
|
||||
@@ -31,6 +32,8 @@ var claudeModelMap = map[string]string{
|
||||
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
|
||||
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
|
||||
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
|
||||
"claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
|
||||
"claude-opus-4-20250514": "claude-opus-4@20250514",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
@@ -80,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// suffix -thinking and -nothinking
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||
info.UpstreamModelName = parts[0]
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
}
|
||||
}
|
||||
@@ -93,14 +99,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
} else {
|
||||
suffix = "generateContent"
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
if info.IsStream {
|
||||
suffix = "streamRawPredict?alt=sse"
|
||||
@@ -111,14 +126,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
model = v
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
@@ -190,7 +214,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
case RequestModeGemini:
|
||||
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
}
|
||||
@@ -199,7 +227,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||
case RequestModeGemini:
|
||||
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
package volcengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -30,8 +34,146 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesEdits:
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
writer.WriteField("model", request.Model)
|
||||
// 获取所有表单字段
|
||||
formData := c.Request.PostForm
|
||||
// 遍历表单字段并打印输出
|
||||
for key, values := range formData {
|
||||
if key == "model" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
writer.WriteField(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the multipart form to handle both single image and multiple images
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
|
||||
return nil, errors.New("failed to parse multipart form")
|
||||
}
|
||||
|
||||
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
||||
// Check if "image" field exists in any form, including array notation
|
||||
var imageFiles []*multipart.FileHeader
|
||||
var exists bool
|
||||
|
||||
// First check for standard "image" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
||||
// If not found, check for "image[]" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||
// If still not found, iterate through all fields to find any that start with "image["
|
||||
foundArrayImages := false
|
||||
for fieldName, files := range c.Request.MultipartForm.File {
|
||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||
foundArrayImages = true
|
||||
for _, file := range files {
|
||||
imageFiles = append(imageFiles, file)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no image fields found at all
|
||||
if !foundArrayImages && (len(imageFiles) == 0) {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process all image files
|
||||
for i, fileHeader := range imageFiles {
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// If multiple images, use image[] as the field name
|
||||
fieldName := "image"
|
||||
if len(imageFiles) > 1 {
|
||||
fieldName = "image[]"
|
||||
}
|
||||
|
||||
// Determine MIME type based on file extension
|
||||
mimeType := detectImageMimeType(fileHeader.Filename)
|
||||
|
||||
// Create a form file with the appropriate content type
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
|
||||
h.Set("Content-Type", mimeType)
|
||||
|
||||
part, err := writer.CreatePart(h)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle mask file if present
|
||||
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
||||
maskFile, err := maskFiles[0].Open()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to open mask file")
|
||||
}
|
||||
defer maskFile.Close()
|
||||
|
||||
// Determine MIME type for mask file
|
||||
mimeType := detectImageMimeType(maskFiles[0].Filename)
|
||||
|
||||
// Create a form file with the appropriate content type
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
|
||||
h.Set("Content-Type", mimeType)
|
||||
|
||||
maskPart, err := writer.CreatePart(h)
|
||||
if err != nil {
|
||||
return nil, errors.New("create form file failed for mask")
|
||||
}
|
||||
|
||||
if _, err := io.Copy(maskPart, maskFile); err != nil {
|
||||
return nil, errors.New("copy mask file failed")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("no multipart form data found")
|
||||
}
|
||||
|
||||
// 关闭 multipart 编写器以设置分界线
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return bytes.NewReader(requestBody.Bytes()), nil
|
||||
|
||||
default:
|
||||
return request, nil
|
||||
}
|
||||
}
|
||||
|
||||
// detectImageMimeType determines the MIME type based on the file extension
|
||||
func detectImageMimeType(filename string) string {
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg":
|
||||
return "image/jpeg"
|
||||
case ".png":
|
||||
return "image/png"
|
||||
case ".webp":
|
||||
return "image/webp"
|
||||
default:
|
||||
// Try to detect from extension if possible
|
||||
if strings.HasPrefix(ext, ".jp") {
|
||||
return "image/jpeg"
|
||||
}
|
||||
// Default to png as a fallback
|
||||
return "image/png"
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -46,6 +188,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
@@ -91,6 +235,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user