diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-alpha.yml
similarity index 72%
rename from .github/workflows/docker-image-amd64.yml
rename to .github/workflows/docker-image-alpha.yml
index a823151c..c02bd409 100644
--- a/.github/workflows/docker-image-amd64.yml
+++ b/.github/workflows/docker-image-alpha.yml
@@ -1,14 +1,15 @@
-name: Publish Docker image (amd64)
+name: Publish Docker image (alpha)
on:
push:
- tags:
- - '*'
+ branches:
+ - alpha
workflow_dispatch:
inputs:
name:
- description: 'reason'
+ description: "reason"
required: false
+
jobs:
push_to_registries:
name: Push Docker image to multiple registries
@@ -22,7 +23,7 @@ jobs:
- name: Save version info
run: |
- git describe --tags > VERSION
+ echo "alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)" > VERSION
- name: Log in to Docker Hub
uses: docker/login-action@v3
@@ -37,6 +38,9 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v5
@@ -44,11 +48,15 @@ jobs:
images: |
calciumion/new-api
ghcr.io/${{ github.repository }}
+ tags: |
+ type=raw,value=alpha
+ type=raw,value=alpha-{{date 'YYYYMMDD'}}-{{sha}}
- name: Build and push Docker images
uses: docker/build-push-action@v5
with:
context: .
+ platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
- labels: ${{ steps.meta.outputs.labels }}
\ No newline at end of file
+ labels: ${{ steps.meta.outputs.labels }}
diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml
index d7468c8e..8e4656aa 100644
--- a/.github/workflows/docker-image-arm64.yml
+++ b/.github/workflows/docker-image-arm64.yml
@@ -1,14 +1,9 @@
-name: Publish Docker image (arm64)
+name: Publish Docker image (Multi Registries)
on:
push:
tags:
- '*'
- workflow_dispatch:
- inputs:
- name:
- description: 'reason'
- required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml
index 3ddabc6d..c87fcfce 100644
--- a/.github/workflows/linux-release.yml
+++ b/.github/workflows/linux-release.yml
@@ -3,6 +3,11 @@ permissions:
contents: write
on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'reason'
+ required: false
push:
tags:
- '*'
@@ -15,16 +20,16 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- - uses: actions/setup-node@v3
+ - uses: oven-sh/setup-bun@v2
with:
- node-version: 18
+ bun-version: latest
- name: Build Frontend
env:
CI: ""
run: |
cd web
- npm install
- REACT_APP_VERSION=$(git describe --tags) npm run build
+ bun install
+ DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml
index ccc480bf..3210065b 100644
--- a/.github/workflows/macos-release.yml
+++ b/.github/workflows/macos-release.yml
@@ -3,6 +3,11 @@ permissions:
contents: write
on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'reason'
+ required: false
push:
tags:
- '*'
@@ -15,16 +20,16 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- - uses: actions/setup-node@v3
+ - uses: oven-sh/setup-bun@v2
with:
- node-version: 18
+ bun-version: latest
- name: Build Frontend
env:
CI: ""
run: |
cd web
- npm install
- REACT_APP_VERSION=$(git describe --tags) npm run build
+ bun install
+ DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml
index f9500718..de3d83d5 100644
--- a/.github/workflows/windows-release.yml
+++ b/.github/workflows/windows-release.yml
@@ -3,6 +3,11 @@ permissions:
contents: write
on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'reason'
+ required: false
push:
tags:
- '*'
@@ -18,16 +23,16 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- - uses: actions/setup-node@v3
+ - uses: oven-sh/setup-bun@v2
with:
- node-version: 18
+ bun-version: latest
- name: Build Frontend
env:
CI: ""
run: |
cd web
- npm install
- REACT_APP_VERSION=$(git describe --tags) npm run build
+ bun install
+ DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
diff --git a/Dockerfile b/Dockerfile
index 214ceaa3..3b42089b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -24,8 +24,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
FROM alpine
-RUN apk update \
- && apk upgrade \
+RUN apk upgrade --no-cache \
&& apk add --no-cache ca-certificates tzdata ffmpeg \
&& update-ca-certificates
diff --git a/README.en.md b/README.en.md
index 4709bc5b..10a3cdb0 100644
--- a/README.en.md
+++ b/README.en.md
@@ -44,6 +44,9 @@
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
+You can also access the AI-generated DeepWiki:
+[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api)
+
## ✨ Key Features
New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:
@@ -110,6 +113,7 @@ For detailed configuration instructions, please refer to [Installation Guide-Env
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
+- `ERROR_LOG_ENABLED=true`: Whether to record and display error logs, default is `false`
## Deployment
diff --git a/README.md b/README.md
index a807b07d..6ba3574c 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,9 @@
+
+
+
@@ -44,6 +47,9 @@
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
+也可访问AI生成的DeepWiki:
+[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api)
+
## ✨ 主要特性
New API提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction):
@@ -110,6 +116,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
+- `ERROR_LOG_ENABLED=true`: 是否记录并显示错误日志,默认`false`
## 部署
@@ -176,7 +183,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
其他基于New API的项目:
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
-- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
## 帮助支持
diff --git a/common/constants.go b/common/constants.go
index bee00506..ac803148 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -241,6 +241,7 @@ const (
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
+ ChannelTypeKling = 50
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
+ "https://api.klingai.com", //50
}
diff --git a/common/database.go b/common/database.go
index 3c0a944b..9cbaf46a 100644
--- a/common/database.go
+++ b/common/database.go
@@ -1,7 +1,14 @@
package common
+const (
+ DatabaseTypeMySQL = "mysql"
+ DatabaseTypeSQLite = "sqlite"
+ DatabaseTypePostgreSQL = "postgres"
+)
+
var UsingSQLite = false
var UsingPostgreSQL = false
+var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
var UsingMySQL = false
var UsingClickHouse = false
diff --git a/common/redis.go b/common/redis.go
index 49d3ec78..1efc217f 100644
--- a/common/redis.go
+++ b/common/redis.go
@@ -92,12 +92,12 @@ func RedisDel(key string) error {
return RDB.Del(ctx, key).Err()
}
-func RedisHDelObj(key string) error {
+func RedisDelKey(key string) error {
if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
+ SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
}
ctx := context.Background()
- return RDB.HDel(ctx, key).Err()
+ return RDB.Del(ctx, key).Err()
}
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
@@ -141,7 +141,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
txn := RDB.TxPipeline()
txn.HSet(ctx, key, data)
- txn.Expire(ctx, key, expiration)
+
+ // 只有在 expiration 大于 0 时才设置过期时间
+ if expiration > 0 {
+ txn.Expire(ctx, key, expiration)
+ }
_, err := txn.Exec(ctx)
if err != nil {
diff --git a/common/utils.go b/common/utils.go
index 587de537..17aecd95 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -13,6 +13,7 @@ import (
"math/big"
"math/rand"
"net"
+ "net/url"
"os"
"os/exec"
"runtime"
@@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
}
// GetAudioDuration returns the duration of an audio file in seconds.
-func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
+func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration")
}
+ durationStr := string(bytes.TrimSpace(output))
+ if durationStr == "N/A" {
+ // Create a temporary output file name
+ tmpFp, err := os.CreateTemp("", "audio-*"+ext)
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to create temporary file")
+ }
+ tmpName := tmpFp.Name()
+ // Close immediately so ffmpeg can open the file on Windows.
+ _ = tmpFp.Close()
+ defer os.Remove(tmpName)
- return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
+ // ffmpeg -y -i filename -vcodec copy -acodec copy
+ ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
+ if err := ffmpegCmd.Run(); err != nil {
+ return 0, errors.Wrap(err, "failed to run ffmpeg")
+ }
+
+ // Recalculate the duration of the new file
+ c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
+ output, err := c.Output()
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
+ }
+ durationStr = string(bytes.TrimSpace(output))
+ }
+ return strconv.ParseFloat(durationStr, 64)
+}
+
+// BuildURL concatenates base and endpoint, returns the complete url string
+func BuildURL(base string, endpoint string) string {
+ u, err := url.Parse(base)
+ if err != nil {
+ return base + endpoint
+ }
+ end := endpoint
+ if end == "" {
+ end = "/"
+ }
+ ref, err := url.Parse(end)
+ if err != nil {
+ return base + endpoint
+ }
+ return u.ResolveReference(ref).String()
}
diff --git a/constant/cache_key.go b/constant/cache_key.go
index 27cb3b75..daedfd40 100644
--- a/constant/cache_key.go
+++ b/constant/cache_key.go
@@ -2,12 +2,10 @@ package constant
import "one-api/common"
-var (
- TokenCacheSeconds = common.SyncFrequency
- UserId2GroupCacheSeconds = common.SyncFrequency
- UserId2QuotaCacheSeconds = common.SyncFrequency
- UserId2StatusCacheSeconds = common.SyncFrequency
-)
+// 使用函数来避免初始化顺序带来的赋值问题
+func RedisKeyCacheSeconds() int {
+ return common.SyncFrequency
+}
// Cache keys
const (
diff --git a/constant/task.go b/constant/task.go
index 1a68b812..d466fc8a 100644
--- a/constant/task.go
+++ b/constant/task.go
@@ -5,6 +5,7 @@ type TaskPlatform string
const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
+ TaskPlatformKling TaskPlatform = "kling"
)
const (
diff --git a/constant/user_setting.go b/constant/user_setting.go
index 055884f7..7e79035e 100644
--- a/constant/user_setting.go
+++ b/constant/user_setting.go
@@ -7,6 +7,7 @@ var (
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
+ UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP
)
var (
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 2bda0fd2..9bf5d1fe 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -4,11 +4,13 @@ import (
"encoding/json"
"errors"
"fmt"
+ "github.com/shopspring/decimal"
"io"
"net/http"
"one-api/common"
"one-api/model"
"one-api/service"
+ "one-api/setting"
"strconv"
"time"
@@ -304,6 +306,40 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
return balance, nil
}
+func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
+ url := "https://api.moonshot.cn/v1/users/me/balance"
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+
+ type MoonshotBalanceData struct {
+ AvailableBalance float64 `json:"available_balance"`
+ VoucherBalance float64 `json:"voucher_balance"`
+ CashBalance float64 `json:"cash_balance"`
+ }
+
+ type MoonshotBalanceResponse struct {
+ Code int `json:"code"`
+ Data MoonshotBalanceData `json:"data"`
+ Scode string `json:"scode"`
+ Status bool `json:"status"`
+ }
+
+ response := MoonshotBalanceResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ if !response.Status || response.Code != 0 {
+ return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
+ }
+ availableBalanceCny := response.Data.AvailableBalance
+ availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
+ channel.UpdateBalance(availableBalanceUsd)
+ return availableBalanceUsd, nil
+}
+
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
@@ -332,6 +368,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelDeepSeekBalance(channel)
case common.ChannelTypeOpenRouter:
return updateChannelOpenRouterBalance(channel)
+ case common.ChannelTypeMoonshot:
+ return updateChannelMoonshotBalance(channel)
default:
return 0, errors.New("尚未实现")
}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index d1cb4093..d54ccf0d 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if channel.Type == common.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil
}
+ if channel.Type == common.ChannelTypeKling {
+ return errors.New("kling channel test is not supported"), nil
+ }
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -90,7 +93,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
info := relaycommon.GenRelayInfo(c)
- err = helper.ModelMappedHelper(c, info)
+ err = helper.ModelMappedHelper(c, info, nil)
if err != nil {
return err, nil
}
@@ -165,8 +168,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
- other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
- usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
+ other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
+ usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
@@ -200,10 +203,10 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
} else {
testRequest.MaxTokens = 10
}
- content, _ := json.Marshal("hi")
+
testMessage := dto.Message{
Role: "user",
- Content: content,
+ Content: "hi",
}
testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage)
@@ -271,6 +274,13 @@ func testAllChannels(notify bool) error {
disableThreshold = 10000000 // a impossible value
}
gopool.Go(func() {
+ // 使用 defer 确保无论如何都会重置运行状态,防止死锁
+ defer func() {
+ testAllChannelsLock.Lock()
+ testAllChannelsRunning = false
+ testAllChannelsLock.Unlock()
+ }()
+
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
@@ -305,9 +315,7 @@ func testAllChannels(notify bool) error {
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
- testAllChannelsLock.Lock()
- testAllChannelsRunning = false
- testAllChannelsLock.Unlock()
+
if notify {
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
diff --git a/controller/channel.go b/controller/channel.go
index ad85fe24..13ed72b3 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -43,22 +43,31 @@ type OpenAIModelsResponse struct {
func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
}
- if pageSize < 0 {
+ if pageSize < 1 {
pageSize = common.ItemsPerPage
}
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
+ // type filter
+ typeStr := c.Query("type")
+ typeFilter := -1
+ if typeStr != "" {
+ if t, err := strconv.Atoi(typeStr); err == nil {
+ typeFilter = t
+ }
+ }
+
+ var total int64
+
if enableTagMode {
- tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
+ // tag 分页:先分页 tag,再取各 tag 下 channels
+ tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
for _, tag := range tags {
@@ -69,21 +78,39 @@ func GetAllChannels(c *gin.Context) {
}
}
}
- } else {
- channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
+ // 计算 tag 总数用于分页
+ total, _ = model.CountAllTags()
+ } else if typeFilter >= 0 {
+ channels, err := model.GetChannelsByType((p-1)*pageSize, pageSize, idSort, typeFilter)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
channelData = channels
+ total, _ = model.CountChannelsByType(typeFilter)
+ } else {
+ channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ channelData = channels
+ total, _ = model.CountAllChannels()
}
+
+ // calculate type counts
+ typeCounts, _ := model.CountChannelsGroupByType()
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": channelData,
+ "data": gin.H{
+ "items": channelData,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ "type_counts": typeCounts,
+ },
})
return
}
@@ -119,8 +146,11 @@ func FetchUpstreamModels(c *gin.Context) {
baseURL = channel.GetBaseURL()
}
url := fmt.Sprintf("%s/v1/models", baseURL)
- if channel.Type == common.ChannelTypeGemini {
+ switch channel.Type {
+ case common.ChannelTypeGemini:
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
+ case common.ChannelTypeAli:
+ url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
}
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
@@ -207,10 +237,20 @@ func SearchChannels(c *gin.Context) {
}
channelData = channels
}
+
+ // calculate type counts for search results
+ typeCounts := make(map[int64]int64)
+ for _, channel := range channelData {
+ typeCounts[int64(channel.Type)]++
+ }
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": channelData,
+ "data": gin.H{
+ "items": channelData,
+ "type_counts": typeCounts,
+ },
})
return
}
@@ -620,3 +660,44 @@ func BatchSetChannelTag(c *gin.Context) {
})
return
}
+
+func GetTagModels(c *gin.Context) {
+ tag := c.Query("tag")
+ if tag == "" {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "message": "tag不能为空",
+ })
+ return
+ }
+
+ channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ var longestModels string
+ maxLength := 0
+
+ // Find the longest models string among all channels with the given tag
+ for _, channel := range channels {
+ if channel.Models != "" {
+ currentModels := strings.Split(channel.Models, ",")
+ if len(currentModels) > maxLength {
+ maxLength = len(currentModels)
+ longestModels = channel.Models
+ }
+ }
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": longestModels,
+ })
+ return
+}
diff --git a/controller/console_migrate.go b/controller/console_migrate.go
new file mode 100644
index 00000000..d25f199b
--- /dev/null
+++ b/controller/console_migrate.go
@@ -0,0 +1,103 @@
+// 用于迁移检测的旧键,该文件下个版本会删除
+
+package controller
+
+import (
+ "encoding/json"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "github.com/gin-gonic/gin"
+)
+
+// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
+func MigrateConsoleSetting(c *gin.Context) {
+ // 读取全部 option
+ opts, err := model.AllOption()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ // 建立 map
+ valMap := map[string]string{}
+ for _, o := range opts {
+ valMap[o.Key] = o.Value
+ }
+
+ // 处理 APIInfo
+ if v := valMap["ApiInfo"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ if len(arr) > 50 {
+ arr = arr[:50]
+ }
+ bytes, _ := json.Marshal(arr)
+ model.UpdateOption("console_setting.api_info", string(bytes))
+ }
+ model.UpdateOption("ApiInfo", "")
+ }
+ // Announcements 直接搬
+ if v := valMap["Announcements"]; v != "" {
+ model.UpdateOption("console_setting.announcements", v)
+ model.UpdateOption("Announcements", "")
+ }
+ // FAQ 转换
+ if v := valMap["FAQ"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ out := []map[string]interface{}{}
+ for _, item := range arr {
+ q, _ := item["question"].(string)
+ if q == "" {
+ q, _ = item["title"].(string)
+ }
+ a, _ := item["answer"].(string)
+ if a == "" {
+ a, _ = item["content"].(string)
+ }
+ if q != "" && a != "" {
+ out = append(out, map[string]interface{}{"question": q, "answer": a})
+ }
+ }
+ if len(out) > 50 {
+ out = out[:50]
+ }
+ bytes, _ := json.Marshal(out)
+ model.UpdateOption("console_setting.faq", string(bytes))
+ }
+ model.UpdateOption("FAQ", "")
+ }
+ // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
+ url := valMap["UptimeKumaUrl"]
+ slug := valMap["UptimeKumaSlug"]
+ if url != "" && slug != "" {
+ // 仅当同时存在 URL 与 Slug 时才进行迁移
+ groups := []map[string]interface{}{
+ {
+ "id": 1,
+ "categoryName": "old",
+ "url": url,
+ "slug": slug,
+ "description": "",
+ },
+ }
+ bytes, _ := json.Marshal(groups)
+ model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
+ }
+ // 清空旧键内容
+ if url != "" {
+ model.UpdateOption("UptimeKumaUrl", "")
+ }
+ if slug != "" {
+ model.UpdateOption("UptimeKumaSlug", "")
+ }
+
+ // 删除旧键记录
+ oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
+ model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
+
+ // 重新加载 OptionMap
+ model.InitOptionMap()
+ common.SysLog("console setting migrated")
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
+}
\ No newline at end of file
diff --git a/controller/group.go b/controller/group.go
index 2c725a4d..2565b6ea 100644
--- a/controller/group.go
+++ b/controller/group.go
@@ -1,15 +1,17 @@
package controller
import (
- "github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"one-api/setting"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
- for groupName, _ := range setting.GetGroupRatioCopy() {
+ for groupName := range ratio_setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.GetUserGroup(userId, false)
- for groupName, ratio := range setting.GetGroupRatioCopy() {
+ for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := setting.GetUserUsableGroups(userGroup)
if desc, ok := userUsableGroups[groupName]; ok {
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
}
}
}
+ if setting.GroupInUserUsableGroups("auto") {
+ usableGroups["auto"] = map[string]interface{}{
+ "ratio": "自动",
+ "desc": setting.GetUsableGroupDescription("auto"),
+ }
+ }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 21027d8f..56bdcb80 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -7,7 +7,6 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"io"
- "log"
"net/http"
"one-api/common"
"one-api/dto"
@@ -215,8 +214,12 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
+ }
+ pageSize, _ := strconv.Atoi(c.Query("page_size"))
+ if pageSize <= 0 {
+ pageSize = common.ItemsPerPage
}
// 解析其他查询参数
@@ -227,31 +230,38 @@ func GetAllMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"),
}
- logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Midjourney, 0)
- }
+ items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams)
+ total := model.CountAllTasks(queryParams)
+
if setting.MjForwardUrlEnabled {
- for i, midjourney := range logs {
+ for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
- logs[i] = midjourney
+ items[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
- "data": logs,
+ "data": gin.H{
+ "items": items,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ },
})
}
func GetUserMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
+ }
+ pageSize, _ := strconv.Atoi(c.Query("page_size"))
+ if pageSize <= 0 {
+ pageSize = common.ItemsPerPage
}
userId := c.GetInt("id")
- log.Printf("userId = %d \n", userId)
queryParams := model.TaskQueryParams{
MjID: c.Query("mj_id"),
@@ -259,19 +269,23 @@ func GetUserMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"),
}
- logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Midjourney, 0)
- }
+ items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
+ total := model.CountAllUserTask(userId, queryParams)
+
if setting.MjForwardUrlEnabled {
- for i, midjourney := range logs {
+ for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
- logs[i] = midjourney
+ items[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
- "data": logs,
+ "data": gin.H{
+ "items": items,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ },
})
}
diff --git a/controller/misc.go b/controller/misc.go
index 4d265c3f..4ffe86f4 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -6,8 +6,10 @@ import (
"net/http"
"one-api/common"
"one-api/constant"
+ "one-api/middleware"
"one-api/model"
"one-api/setting"
+ "one-api/setting/console_setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"strings"
@@ -24,57 +26,85 @@ func TestStatus(c *gin.Context) {
})
return
}
+ // 获取HTTP统计信息
+ httpStats := middleware.GetStats()
c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "Server is running",
+ "success": true,
+ "message": "Server is running",
+ "http_stats": httpStats,
})
return
}
func GetStatus(c *gin.Context) {
+
+ cs := console_setting.GetConsoleSetting()
+
+ data := gin.H{
+ "version": common.Version,
+ "start_time": common.StartTime,
+ "email_verification": common.EmailVerificationEnabled,
+ "github_oauth": common.GitHubOAuthEnabled,
+ "github_client_id": common.GitHubClientId,
+ "linuxdo_oauth": common.LinuxDOOAuthEnabled,
+ "linuxdo_client_id": common.LinuxDOClientId,
+ "telegram_oauth": common.TelegramOAuthEnabled,
+ "telegram_bot_name": common.TelegramBotName,
+ "system_name": common.SystemName,
+ "logo": common.Logo,
+ "footer_html": common.Footer,
+ "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
+ "wechat_login": common.WeChatAuthEnabled,
+ "server_address": setting.ServerAddress,
+ "price": setting.Price,
+ "min_topup": setting.MinTopUp,
+ "turnstile_check": common.TurnstileCheckEnabled,
+ "turnstile_site_key": common.TurnstileSiteKey,
+ "top_up_link": common.TopUpLink,
+ "docs_link": operation_setting.GetGeneralSetting().DocsLink,
+ "quota_per_unit": common.QuotaPerUnit,
+ "display_in_currency": common.DisplayInCurrencyEnabled,
+ "enable_batch_update": common.BatchUpdateEnabled,
+ "enable_drawing": common.DrawingEnabled,
+ "enable_task": common.TaskEnabled,
+ "enable_data_export": common.DataExportEnabled,
+ "data_export_default_time": common.DataExportDefaultTime,
+ "default_collapse_sidebar": common.DefaultCollapseSidebar,
+ "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
+ "mj_notify_enabled": setting.MjNotifyEnabled,
+ "chats": setting.Chats,
+ "demo_site_enabled": operation_setting.DemoSiteEnabled,
+ "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
+ "default_use_auto_group": setting.DefaultUseAutoGroup,
+ "pay_methods": setting.PayMethods,
+
+ // 面板启用开关
+ "api_info_enabled": cs.ApiInfoEnabled,
+ "uptime_kuma_enabled": cs.UptimeKumaEnabled,
+ "announcements_enabled": cs.AnnouncementsEnabled,
+ "faq_enabled": cs.FAQEnabled,
+
+ "oidc_enabled": system_setting.GetOIDCSettings().Enabled,
+ "oidc_client_id": system_setting.GetOIDCSettings().ClientId,
+ "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
+ "setup": constant.Setup,
+ }
+
+ // 根据启用状态注入可选内容
+ if cs.ApiInfoEnabled {
+ data["api_info"] = console_setting.GetApiInfo()
+ }
+ if cs.AnnouncementsEnabled {
+ data["announcements"] = console_setting.GetAnnouncements()
+ }
+ if cs.FAQEnabled {
+ data["faq"] = console_setting.GetFAQ()
+ }
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": gin.H{
- "version": common.Version,
- "start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "linuxdo_oauth": common.LinuxDOOAuthEnabled,
- "linuxdo_client_id": common.LinuxDOClientId,
- "telegram_oauth": common.TelegramOAuthEnabled,
- "telegram_bot_name": common.TelegramBotName,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": setting.ServerAddress,
- "price": setting.Price,
- "min_topup": setting.MinTopUp,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "docs_link": operation_setting.GetGeneralSetting().DocsLink,
- "quota_per_unit": common.QuotaPerUnit,
- "display_in_currency": common.DisplayInCurrencyEnabled,
- "enable_batch_update": common.BatchUpdateEnabled,
- "enable_drawing": common.DrawingEnabled,
- "enable_task": common.TaskEnabled,
- "enable_data_export": common.DataExportEnabled,
- "data_export_default_time": common.DataExportDefaultTime,
- "default_collapse_sidebar": common.DefaultCollapseSidebar,
- "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
- "mj_notify_enabled": setting.MjNotifyEnabled,
- "chats": setting.Chats,
- "demo_site_enabled": operation_setting.DemoSiteEnabled,
- "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
- "oidc_enabled": system_setting.GetOIDCSettings().Enabled,
- "oidc_client_id": system_setting.GetOIDCSettings().ClientId,
- "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
- "setup": constant.Setup,
- },
+ "data": data,
})
return
}
diff --git a/controller/model.go b/controller/model.go
index df7e59a6..78bd32d6 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -2,7 +2,7 @@ package controller
import (
"fmt"
- "github.com/gin-gonic/gin"
+ "github.com/samber/lo"
"net/http"
"one-api/common"
"one-api/constant"
@@ -15,6 +15,9 @@ import (
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/setting"
+
+ "github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -134,6 +137,9 @@ func init() {
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
+ openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
+ return m.Id
+ })
}
func ListModels(c *gin.Context) {
@@ -179,7 +185,19 @@ func ListModels(c *gin.Context) {
if tokenGroup != "" {
group = tokenGroup
}
- models := model.GetGroupModels(group)
+ var models []string
+ if tokenGroup == "auto" {
+ for _, autoGroup := range setting.AutoGroups {
+ groupModels := model.GetGroupModels(autoGroup)
+ for _, g := range groupModels {
+ if !common.StringsContains(models, g) {
+ models = append(models, g)
+ }
+ }
+ }
+ } else {
+ models = model.GetGroupModels(group)
+ }
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
diff --git a/controller/option.go b/controller/option.go
index 250f16bb..97bb6a5a 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -6,6 +6,8 @@ import (
"one-api/common"
"one-api/model"
"one-api/setting"
+ "one-api/setting/console_setting"
+ "one-api/setting/ratio_setting"
"one-api/setting/system_setting"
"strings"
@@ -102,7 +104,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "GroupRatio":
- err = setting.CheckGroupRatio(option.Value)
+ err = ratio_setting.CheckGroupRatio(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -119,7 +121,42 @@ func UpdateOption(c *gin.Context) {
})
return
}
-
+ case "console_setting.api_info":
+ err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.announcements":
+ err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.faq":
+ err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.uptime_kuma_groups":
+ err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
}
err = model.UpdateOption(option.Key, option.Value)
if err != nil {
diff --git a/controller/playground.go b/controller/playground.go
index a2b54790..10393250 100644
--- a/controller/playground.go
+++ b/controller/playground.go
@@ -3,7 +3,6 @@ package controller
import (
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -13,6 +12,8 @@ import (
"one-api/service"
"one-api/setting"
"time"
+
+ "github.com/gin-gonic/gin"
)
func Playground(c *gin.Context) {
@@ -57,13 +58,22 @@ func Playground(c *gin.Context) {
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
- channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
+ channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
c.Set(constant.ContextKeyRequestStartTime, time.Now())
+
+ // Write user context to ensure acceptUnsetRatio is available
+ userId := c.GetInt("id")
+ userCache, err := model.GetUserCache(userId)
+ if err != nil {
+ openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
+ return
+ }
+ userCache.WriteContext(c)
Relay(c)
}
diff --git a/controller/pricing.go b/controller/pricing.go
index 1cbfe731..f27336b7 100644
--- a/controller/pricing.go
+++ b/controller/pricing.go
@@ -1,10 +1,11 @@
package controller
import (
- "github.com/gin-gonic/gin"
"one-api/model"
"one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
)
func GetPricing(c *gin.Context) {
@@ -12,7 +13,7 @@ func GetPricing(c *gin.Context) {
userId, exists := c.Get("id")
usableGroup := map[string]string{}
groupRatio := map[string]float64{}
- for s, f := range setting.GetGroupRatioCopy() {
+ for s, f := range ratio_setting.GetGroupRatioCopy() {
groupRatio[s] = f
}
var group string
@@ -20,12 +21,18 @@ func GetPricing(c *gin.Context) {
user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
+ for g := range groupRatio {
+ ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
+ if ok {
+ groupRatio[g] = ratio
+ }
+ }
}
}
usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup
- for group := range setting.GetGroupRatioCopy() {
+ for group := range ratio_setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group)
}
@@ -40,7 +47,7 @@ func GetPricing(c *gin.Context) {
}
func ResetModelRatio(c *gin.Context) {
- defaultStr := operation_setting.DefaultModelRatio2JSONString()
+ defaultStr := ratio_setting.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
@@ -49,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
})
return
}
- err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
+ err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
diff --git a/controller/ratio_config.go b/controller/ratio_config.go
new file mode 100644
index 00000000..6ddc3d9e
--- /dev/null
+++ b/controller/ratio_config.go
@@ -0,0 +1,24 @@
+package controller
+
+import (
+ "net/http"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetRatioConfig(c *gin.Context) {
+ if !ratio_setting.IsExposeRatioEnabled() {
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "倍率配置接口未启用",
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": ratio_setting.GetExposedData(),
+ })
+}
diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go
new file mode 100644
index 00000000..f749f384
--- /dev/null
+++ b/controller/ratio_sync.go
@@ -0,0 +1,322 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "one-api/common"
+ "one-api/dto"
+ "one-api/model"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ defaultTimeoutSeconds = 10
+ defaultEndpoint = "/api/ratio_config"
+ maxConcurrentFetches = 8
+)
+
+var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
+
+type upstreamResult struct {
+ Name string `json:"name"`
+ Data map[string]any `json:"data,omitempty"`
+ Err string `json:"err,omitempty"`
+}
+
+func FetchUpstreamRatios(c *gin.Context) {
+ var req dto.UpstreamRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+
+ if req.Timeout <= 0 {
+ req.Timeout = defaultTimeoutSeconds
+ }
+
+ var upstreams []dto.UpstreamDTO
+
+ if len(req.ChannelIDs) > 0 {
+ intIds := make([]int, 0, len(req.ChannelIDs))
+ for _, id64 := range req.ChannelIDs {
+ intIds = append(intIds, int(id64))
+ }
+ dbChannels, err := model.GetChannelsByIds(intIds)
+ if err != nil {
+ common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
+ return
+ }
+ for _, ch := range dbChannels {
+ if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
+ upstreams = append(upstreams, dto.UpstreamDTO{
+ Name: ch.Name,
+ BaseURL: strings.TrimRight(base, "/"),
+ Endpoint: "",
+ })
+ }
+ }
+ }
+
+ if len(upstreams) == 0 {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
+ return
+ }
+
+ var wg sync.WaitGroup
+ ch := make(chan upstreamResult, len(upstreams))
+
+ sem := make(chan struct{}, maxConcurrentFetches)
+
+ client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
+
+ for _, chn := range upstreams {
+ wg.Add(1)
+ go func(chItem dto.UpstreamDTO) {
+ defer wg.Done()
+
+ sem <- struct{}{}
+ defer func() { <-sem }()
+
+ endpoint := chItem.Endpoint
+ if endpoint == "" {
+ endpoint = defaultEndpoint
+ } else if !strings.HasPrefix(endpoint, "/") {
+ endpoint = "/" + endpoint
+ }
+ fullURL := chItem.BaseURL + endpoint
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
+ defer cancel()
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
+ if err != nil {
+ common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
+ ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
+ return
+ }
+
+ resp, err := client.Do(httpReq)
+ if err != nil {
+ common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
+ return
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
+ ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
+ return
+ }
+ var body struct {
+ Success bool `json:"success"`
+ Data map[string]any `json:"data"`
+ Message string `json:"message"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+ common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
+ return
+ }
+ if !body.Success {
+ ch <- upstreamResult{Name: chItem.Name, Err: body.Message}
+ return
+ }
+ ch <- upstreamResult{Name: chItem.Name, Data: body.Data}
+ }(chn)
+ }
+
+ wg.Wait()
+ close(ch)
+
+ localData := ratio_setting.GetExposedData()
+
+ var testResults []dto.TestResult
+ var successfulChannels []struct {
+ name string
+ data map[string]any
+ }
+
+ for r := range ch {
+ if r.Err != "" {
+ testResults = append(testResults, dto.TestResult{
+ Name: r.Name,
+ Status: "error",
+ Error: r.Err,
+ })
+ } else {
+ testResults = append(testResults, dto.TestResult{
+ Name: r.Name,
+ Status: "success",
+ })
+ successfulChannels = append(successfulChannels, struct {
+ name string
+ data map[string]any
+ }{name: r.Name, data: r.Data})
+ }
+ }
+
+ differences := buildDifferences(localData, successfulChannels)
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "data": gin.H{
+ "differences": differences,
+ "test_results": testResults,
+ },
+ })
+}
+
+func buildDifferences(localData map[string]any, successfulChannels []struct {
+ name string
+ data map[string]any
+}) map[string]map[string]dto.DifferenceItem {
+ differences := make(map[string]map[string]dto.DifferenceItem)
+
+ allModels := make(map[string]struct{})
+
+ for _, ratioType := range ratioTypes {
+ if localRatioAny, ok := localData[ratioType]; ok {
+ if localRatio, ok := localRatioAny.(map[string]float64); ok {
+ for modelName := range localRatio {
+ allModels[modelName] = struct{}{}
+ }
+ }
+ }
+ }
+
+ for _, channel := range successfulChannels {
+ for _, ratioType := range ratioTypes {
+ if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+ for modelName := range upstreamRatio {
+ allModels[modelName] = struct{}{}
+ }
+ }
+ }
+ }
+
+ for modelName := range allModels {
+ for _, ratioType := range ratioTypes {
+ var localValue interface{} = nil
+ if localRatioAny, ok := localData[ratioType]; ok {
+ if localRatio, ok := localRatioAny.(map[string]float64); ok {
+ if val, exists := localRatio[modelName]; exists {
+ localValue = val
+ }
+ }
+ }
+
+ upstreamValues := make(map[string]interface{})
+ hasUpstreamValue := false
+ hasDifference := false
+
+ for _, channel := range successfulChannels {
+ var upstreamValue interface{} = nil
+
+ if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+ if val, exists := upstreamRatio[modelName]; exists {
+ upstreamValue = val
+ hasUpstreamValue = true
+
+ if localValue != nil && localValue != val {
+ hasDifference = true
+ } else if localValue == val {
+ upstreamValue = "same"
+ }
+ }
+ }
+ if upstreamValue == nil && localValue == nil {
+ upstreamValue = "same"
+ }
+
+ if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
+ hasDifference = true
+ }
+
+ upstreamValues[channel.name] = upstreamValue
+ }
+
+ shouldInclude := false
+
+ if localValue != nil {
+ if hasDifference {
+ shouldInclude = true
+ }
+ } else {
+ if hasUpstreamValue {
+ shouldInclude = true
+ }
+ }
+
+ if shouldInclude {
+ if differences[modelName] == nil {
+ differences[modelName] = make(map[string]dto.DifferenceItem)
+ }
+ differences[modelName][ratioType] = dto.DifferenceItem{
+ Current: localValue,
+ Upstreams: upstreamValues,
+ }
+ }
+ }
+ }
+
+ channelHasDiff := make(map[string]bool)
+ for _, ratioMap := range differences {
+ for _, item := range ratioMap {
+ for chName, val := range item.Upstreams {
+ if val != nil && val != "same" {
+ channelHasDiff[chName] = true
+ }
+ }
+ }
+ }
+
+ for modelName, ratioMap := range differences {
+ for ratioType, item := range ratioMap {
+ for chName := range item.Upstreams {
+ if !channelHasDiff[chName] {
+ delete(item.Upstreams, chName)
+ }
+ }
+ differences[modelName][ratioType] = item
+ }
+ }
+
+ return differences
+}
+
+func GetSyncableChannels(c *gin.Context) {
+ channels, err := model.GetAllChannels(0, 0, true, false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ var syncableChannels []dto.SyncableChannel
+ for _, channel := range channels {
+ if channel.GetBaseURL() != "" {
+ syncableChannels = append(syncableChannels, dto.SyncableChannel{
+ ID: channel.Id,
+ Name: channel.Name,
+ BaseURL: channel.GetBaseURL(),
+ Status: channel.Status,
+ })
+ }
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": syncableChannels,
+ })
+}
diff --git a/controller/redemption.go b/controller/redemption.go
index a7e09a8a..50620597 100644
--- a/controller/redemption.go
+++ b/controller/redemption.go
@@ -5,6 +5,7 @@ import (
"one-api/common"
"one-api/model"
"strconv"
+ "errors"
"github.com/gin-gonic/gin"
)
@@ -126,6 +127,10 @@ func AddRedemption(c *gin.Context) {
})
return
}
+ if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
var keys []string
for i := 0; i < redemption.Count; i++ {
key := common.GetUUID()
@@ -135,6 +140,7 @@ func AddRedemption(c *gin.Context) {
Key: key,
CreatedTime: common.GetTimestamp(),
Quota: redemption.Quota,
+ ExpiredTime: redemption.ExpiredTime,
}
err = cleanRedemption.Insert()
if err != nil {
@@ -191,12 +197,18 @@ func UpdateRedemption(c *gin.Context) {
})
return
}
- if statusOnly != "" {
- cleanRedemption.Status = redemption.Status
- } else {
+ if statusOnly == "" {
+ if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
// If you add more fields, please also update redemption.Update()
cleanRedemption.Name = redemption.Name
cleanRedemption.Quota = redemption.Quota
+ cleanRedemption.ExpiredTime = redemption.ExpiredTime
+ }
+ if statusOnly != "" {
+ cleanRedemption.Status = redemption.Status
}
err = cleanRedemption.Update()
if err != nil {
@@ -213,3 +225,27 @@ func UpdateRedemption(c *gin.Context) {
})
return
}
+
+func DeleteInvalidRedemption(c *gin.Context) {
+ rows, err := model.DeleteInvalidRedemptions()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": rows,
+ })
+ return
+}
+
+func validateExpiredTime(expired int64) error {
+ if expired != 0 && expired < common.GetTimestamp() {
+ return errors.New("过期时间不能早于当前时间")
+ }
+ return nil
+}
diff --git a/controller/relay.go b/controller/relay.go
index 41cb22a5..4da4262b 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
err = relay.EmbeddingHelper(c)
case relayconstant.RelayModeResponses:
err = relay.ResponsesHelper(c)
+ case relayconstant.RelayModeGemini:
+ err = relay.GeminiHelper(c)
default:
err = relay.TextHelper(c)
}
@@ -257,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
+ channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
@@ -386,7 +388,7 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+ channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
@@ -418,7 +420,7 @@ func RelayTask(c *gin.Context) {
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
- case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
+ case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)
diff --git a/controller/setup.go b/controller/setup.go
index 0a13bcf9..8943a1a0 100644
--- a/controller/setup.go
+++ b/controller/setup.go
@@ -75,6 +75,14 @@ func PostSetup(c *gin.Context) {
// If root doesn't exist, validate and create admin account
if !rootExists {
+ // Validate username length: max 12 characters to align with model.User validation
+ if len(req.Username) > 12 {
+ c.JSON(400, gin.H{
+ "success": false,
+ "message": "用户名长度不能超过12个字符",
+ })
+ return
+ }
// Validate password
if req.Password != req.ConfirmPassword {
c.JSON(400, gin.H{
diff --git a/controller/task.go b/controller/task.go
index 65f79ead..f7523e87 100644
--- a/controller/task.go
+++ b/controller/task.go
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
+ case constant.TaskPlatformKling:
+ _ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
default:
common.SysLog("未知平台")
}
@@ -224,9 +226,14 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
func GetAllTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
}
+ pageSize, _ := strconv.Atoi(c.Query("page_size"))
+ if pageSize <= 0 {
+ pageSize = common.ItemsPerPage
+ }
+
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
// 解析其他查询参数
@@ -237,24 +244,32 @@ func GetAllTask(c *gin.Context) {
Action: c.Query("action"),
StartTimestamp: startTimestamp,
EndTimestamp: endTimestamp,
+ ChannelID: c.Query("channel_id"),
}
- logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Task, 0)
- }
+ items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams)
+ total := model.TaskCountAllTasks(queryParams)
c.JSON(200, gin.H{
"success": true,
"message": "",
- "data": logs,
+ "data": gin.H{
+ "items": items,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ },
})
}
func GetUserTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
+ }
+ pageSize, _ := strconv.Atoi(c.Query("page_size"))
+ if pageSize <= 0 {
+ pageSize = common.ItemsPerPage
}
userId := c.GetInt("id")
@@ -271,14 +286,17 @@ func GetUserTask(c *gin.Context) {
EndTimestamp: endTimestamp,
}
- logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Task, 0)
- }
+ items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
+ total := model.TaskCountAllUserTask(userId, queryParams)
c.JSON(200, gin.H{
"success": true,
"message": "",
- "data": logs,
+ "data": gin.H{
+ "items": items,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ },
})
}
diff --git a/controller/task_video.go b/controller/task_video.go
new file mode 100644
index 00000000..a2c2431d
--- /dev/null
+++ b/controller/task_video.go
@@ -0,0 +1,140 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/model"
+ "one-api/relay"
+ "one-api/relay/channel"
+)
+
+func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+ for channelId, taskIds := range taskChannelM {
+ if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
+ common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+ }
+ }
+ return nil
+}
+
+func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+ common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+ if len(taskIds) == 0 {
+ return nil
+ }
+ cacheGetChannel, err := model.CacheGetChannel(channelId)
+ if err != nil {
+ errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
+ "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if errUpdate != nil {
+ common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+ }
+ return fmt.Errorf("CacheGetChannel failed: %w", err)
+ }
+ adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
+ if adaptor == nil {
+ return fmt.Errorf("video adaptor not found")
+ }
+ for _, taskId := range taskIds {
+ if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+ common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+ }
+ }
+ return nil
+}
+
+func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
+ baseURL := common.ChannelBaseURLs[channel.Type]
+ if channel.GetBaseURL() != "" {
+ baseURL = channel.GetBaseURL()
+ }
+ resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
+ "task_id": taskId,
+ })
+ if err != nil {
+ return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
+ }
+ // Close the body before any early return, otherwise a non-200 reply leaks the connection.
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
+ }
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
+ }
+
+ var responseItem map[string]interface{}
+ err = json.Unmarshal(responseBody, &responseItem)
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
+ return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
+ }
+
+ code, _ := responseItem["code"].(float64)
+ if code != 0 {
+ return fmt.Errorf("video task fetch failed for task %s", taskId)
+ }
+
+ data, ok := responseItem["data"].(map[string]interface{})
+ if !ok {
+ common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
+ return fmt.Errorf("video task data format error for task %s", taskId)
+ }
+
+ task := taskM[taskId]
+ if task == nil {
+ common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+ return fmt.Errorf("task %s not found", taskId)
+ }
+
+ if status, ok := data["task_status"].(string); ok {
+ switch status {
+ case "submitted", "queued":
+ task.Status = model.TaskStatusSubmitted
+ case "processing":
+ task.Status = model.TaskStatusInProgress
+ case "succeed":
+ task.Status = model.TaskStatusSuccess
+ task.Progress = "100%"
+ if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
+ task.FailReason = url
+ } else {
+ common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
+ }
+ case "failed":
+ task.Status = model.TaskStatusFailure
+ task.Progress = "100%"
+ if reason, ok := data["fail_reason"].(string); ok {
+ task.FailReason = reason
+ }
+ }
+ }
+
+ // If task failed, refund quota
+ if task.Status == model.TaskStatusFailure {
+ common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+ quota := task.Quota
+ if quota != 0 {
+ if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
+ common.LogError(ctx, "Failed to increase user quota: "+err.Error())
+ }
+ logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
+ model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+ }
+ }
+
+ task.Data = responseBody
+ if err := task.Update(); err != nil {
+ common.SysError("UpdateVideoTask task error: " + err.Error())
+ }
+
+ return nil
+}
diff --git a/controller/token.go b/controller/token.go
index a8803279..c57552c0 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -12,15 +12,15 @@ func GetAllTokens(c *gin.Context) {
userId := c.GetInt("id")
p, _ := strconv.Atoi(c.Query("p"))
size, _ := strconv.Atoi(c.Query("size"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
}
if size <= 0 {
size = common.ItemsPerPage
} else if size > 100 {
size = 100
}
- tokens, err := model.GetAllUserTokens(userId, p*size, size)
+ tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -28,10 +28,18 @@ func GetAllTokens(c *gin.Context) {
})
return
}
+ // Get total count for pagination
+ total, _ := model.CountUserTokens(userId)
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": tokens,
+ "data": gin.H{
+ "items": tokens,
+ "total": total,
+ "page": p,
+ "page_size": size,
+ },
})
return
}
diff --git a/controller/topup.go b/controller/topup.go
index 4654b6ea..827dda39 100644
--- a/controller/topup.go
+++ b/controller/topup.go
@@ -97,16 +97,14 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
- payType := "wxpay"
- if req.PaymentMethod == "zfb" {
- payType = "alipay"
- }
- if req.PaymentMethod == "wx" {
- req.PaymentMethod = "wxpay"
- payType = "wxpay"
+
+ if !setting.ContainsPayMethod(req.PaymentMethod) {
+ c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
+ return
}
+
callBackAddress := service.GetCallbackAddress()
- returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
+ returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
- Type: payType,
+ Type: req.PaymentMethod,
ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
diff --git a/controller/uptime_kuma.go b/controller/uptime_kuma.go
new file mode 100644
index 00000000..05d6297e
--- /dev/null
+++ b/controller/uptime_kuma.go
@@ -0,0 +1,154 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "one-api/setting/console_setting"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "golang.org/x/sync/errgroup"
+)
+
+const (
+ requestTimeout = 30 * time.Second
+ httpTimeout = 10 * time.Second
+ uptimeKeySuffix = "_24"
+ apiStatusPath = "/api/status-page/"
+ apiHeartbeatPath = "/api/status-page/heartbeat/"
+)
+
+type Monitor struct {
+ Name string `json:"name"`
+ Uptime float64 `json:"uptime"`
+ Status int `json:"status"`
+ Group string `json:"group,omitempty"`
+}
+
+type UptimeGroupResult struct {
+ CategoryName string `json:"categoryName"`
+ Monitors []Monitor `json:"monitors"`
+}
+
+func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return err
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return errors.New("non-200 status")
+ }
+
+ return json.NewDecoder(resp.Body).Decode(dest)
+}
+
+func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
+ url, _ := groupConfig["url"].(string)
+ slug, _ := groupConfig["slug"].(string)
+ categoryName, _ := groupConfig["categoryName"].(string)
+
+ result := UptimeGroupResult{
+ CategoryName: categoryName,
+ Monitors: []Monitor{},
+ }
+
+ if url == "" || slug == "" {
+ return result
+ }
+
+ baseURL := strings.TrimSuffix(url, "/")
+
+ var statusData struct {
+ PublicGroupList []struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ MonitorList []struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ } `json:"monitorList"`
+ } `json:"publicGroupList"`
+ }
+
+ var heartbeatData struct {
+ HeartbeatList map[string][]struct {
+ Status int `json:"status"`
+ } `json:"heartbeatList"`
+ UptimeList map[string]float64 `json:"uptimeList"`
+ }
+
+ g, gCtx := errgroup.WithContext(ctx)
+ g.Go(func() error {
+ return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
+ })
+ g.Go(func() error {
+ return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
+ })
+
+ if g.Wait() != nil {
+ return result
+ }
+
+ for _, pg := range statusData.PublicGroupList {
+ if len(pg.MonitorList) == 0 {
+ continue
+ }
+
+ for _, m := range pg.MonitorList {
+ monitor := Monitor{
+ Name: m.Name,
+ Group: pg.Name,
+ }
+
+ monitorID := strconv.Itoa(m.ID)
+
+ if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
+ monitor.Uptime = uptime
+ }
+
+ if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
+ monitor.Status = heartbeats[0].Status
+ }
+
+ result.Monitors = append(result.Monitors, monitor)
+ }
+ }
+
+ return result
+}
+
+func GetUptimeKumaStatus(c *gin.Context) {
+ groups := console_setting.GetUptimeKumaGroups()
+ if len(groups) == 0 {
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
+ defer cancel()
+
+ client := &http.Client{Timeout: httpTimeout}
+ results := make([]UptimeGroupResult, len(groups))
+
+ g, gCtx := errgroup.WithContext(ctx)
+ for i, group := range groups {
+ i, group := i, group
+ g.Go(func() error {
+ results[i] = fetchGroupData(gCtx, client, group)
+ return nil
+ })
+ }
+
+ g.Wait()
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
+}
diff --git a/controller/user.go b/controller/user.go
index fd53e743..e8ce3c3d 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
+ if setting.DefaultUseAutoGroup {
+ token.Group = "auto"
+ }
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -459,6 +462,9 @@ func GetSelf(c *gin.Context) {
})
return
}
+ // Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
+ user.Remark = ""
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -943,6 +949,7 @@ type UpdateUserSettingRequest struct {
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
+ RecordIpLog bool `json:"record_ip_log"`
}
func UpdateUserSetting(c *gin.Context) {
@@ -1019,6 +1026,7 @@ func UpdateUserSetting(c *gin.Context) {
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
+ constant.UserSettingRecordIpLog: req.RecordIpLog,
}
// 如果是webhook类型,添加webhook相关设置
diff --git a/dto/claude.go b/dto/claude.go
index 8068feb8..98e09c78 100644
--- a/dto/claude.go
+++ b/dto/claude.go
@@ -1,29 +1,33 @@
package dto
-import "encoding/json"
+import (
+ "encoding/json"
+ "one-api/common"
+)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeMediaMessage struct {
- Type string `json:"type,omitempty"`
- Text *string `json:"text,omitempty"`
- Model string `json:"model,omitempty"`
- Source *ClaudeMessageSource `json:"source,omitempty"`
- Usage *ClaudeUsage `json:"usage,omitempty"`
- StopReason *string `json:"stop_reason,omitempty"`
- PartialJson *string `json:"partial_json,omitempty"`
- Role string `json:"role,omitempty"`
- Thinking string `json:"thinking,omitempty"`
- Signature string `json:"signature,omitempty"`
- Delta string `json:"delta,omitempty"`
+ Type string `json:"type,omitempty"`
+ Text *string `json:"text,omitempty"`
+ Model string `json:"model,omitempty"`
+ Source *ClaudeMessageSource `json:"source,omitempty"`
+ Usage *ClaudeUsage `json:"usage,omitempty"`
+ StopReason *string `json:"stop_reason,omitempty"`
+ PartialJson *string `json:"partial_json,omitempty"`
+ Role string `json:"role,omitempty"`
+ Thinking string `json:"thinking,omitempty"`
+ Signature string `json:"signature,omitempty"`
+ Delta string `json:"delta,omitempty"`
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
// tool_calls
- Id string `json:"id,omitempty"`
- Name string `json:"name,omitempty"`
- Input any `json:"input,omitempty"`
- Content json.RawMessage `json:"content,omitempty"`
- ToolUseId string `json:"tool_use_id,omitempty"`
+ Id string `json:"id,omitempty"`
+ Name string `json:"name,omitempty"`
+ Input any `json:"input,omitempty"`
+ Content any `json:"content,omitempty"`
+ ToolUseId string `json:"tool_use_id,omitempty"`
}
func (c *ClaudeMediaMessage) SetText(s string) {
@@ -38,15 +42,39 @@ func (c *ClaudeMediaMessage) GetText() string {
}
func (c *ClaudeMediaMessage) IsStringContent() bool {
- var content string
- return json.Unmarshal(c.Content, &content) == nil
+ if c.Content == nil {
+ return false
+ }
+ _, ok := c.Content.(string)
+ if ok {
+ return true
+ }
+ return false
}
func (c *ClaudeMediaMessage) GetStringContent() string {
- var content string
- if err := json.Unmarshal(c.Content, &content); err == nil {
- return content
+ if c.Content == nil {
+ return ""
}
+ switch c.Content.(type) {
+ case string:
+ return c.Content.(string)
+ case []any:
+ var contentStr string
+ for _, contentItem := range c.Content.([]any) {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+
return ""
}
@@ -56,16 +84,12 @@ func (c *ClaudeMediaMessage) GetJsonRowString() string {
}
func (c *ClaudeMediaMessage) SetContent(content any) {
- jsonContent, _ := json.Marshal(content)
- c.Content = jsonContent
+ c.Content = content
}
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
- var mediaContent []ClaudeMediaMessage
- if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
- return mediaContent
- }
- return make([]ClaudeMediaMessage, 0)
+ mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.Content)
+ return mediaContent
}
type ClaudeMessageSource struct {
@@ -81,14 +105,36 @@ type ClaudeMessage struct {
}
func (c *ClaudeMessage) IsStringContent() bool {
+ if c.Content == nil {
+ return false
+ }
_, ok := c.Content.(string)
return ok
}
func (c *ClaudeMessage) GetStringContent() string {
- if c.IsStringContent() {
- return c.Content.(string)
+ if c.Content == nil {
+ return ""
}
+ switch c.Content.(type) {
+ case string:
+ return c.Content.(string)
+ case []any:
+ var contentStr string
+ for _, contentItem := range c.Content.([]any) {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+
return ""
}
@@ -97,15 +143,7 @@ func (c *ClaudeMessage) SetStringContent(content string) {
}
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
- // map content to []ClaudeMediaMessage
- // parse to json
- jsonContent, _ := json.Marshal(c.Content)
- var contentList []ClaudeMediaMessage
- err := json.Unmarshal(jsonContent, &contentList)
- if err != nil {
- return make([]ClaudeMediaMessage, 0), err
- }
- return contentList, nil
+ return common.Any2Type[[]ClaudeMediaMessage](c.Content)
}
type Tool struct {
@@ -140,7 +178,14 @@ type ClaudeRequest struct {
type Thinking struct {
Type string `json:"type"`
- BudgetTokens int `json:"budget_tokens"`
+ BudgetTokens *int `json:"budget_tokens,omitempty"`
+}
+
+func (c *Thinking) GetBudgetTokens() int {
+ if c.BudgetTokens == nil {
+ return 0
+ }
+ return *c.BudgetTokens
}
func (c *ClaudeRequest) IsStringSystem() bool {
@@ -160,14 +205,8 @@ func (c *ClaudeRequest) SetStringSystem(system string) {
}
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
- // map content to []ClaudeMediaMessage
- // parse to json
- jsonContent, _ := json.Marshal(c.System)
- var contentList []ClaudeMediaMessage
- if err := json.Unmarshal(jsonContent, &contentList); err == nil {
- return contentList
- }
- return make([]ClaudeMediaMessage, 0)
+ mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.System)
+ return mediaContent
}
type ClaudeError struct {
diff --git a/dto/dalle.go b/dto/dalle.go
index 44104d33..ce2f6361 100644
--- a/dto/dalle.go
+++ b/dto/dalle.go
@@ -14,6 +14,8 @@ type ImageRequest struct {
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
+ OutputFormat string `json:"output_format,omitempty"`
+ Watermark *bool `json:"watermark,omitempty"`
}
type ImageResponse struct {
diff --git a/dto/openai_request.go b/dto/openai_request.go
index e8833b3d..42c290ca 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -2,6 +2,7 @@ package dto
import (
"encoding/json"
+ "one-api/common"
"strings"
)
@@ -18,41 +19,54 @@ type FormatJsonSchema struct {
}
type GeneralOpenAIRequest struct {
- Model string `json:"model,omitempty"`
- Messages []Message `json:"messages,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Prefix any `json:"prefix,omitempty"`
- Suffix any `json:"suffix,omitempty"`
- Stream bool `json:"stream,omitempty"`
- StreamOptions *StreamOptions `json:"stream_options,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
- ReasoningEffort string `json:"reasoning_effort,omitempty"`
- //Reasoning json.RawMessage `json:"reasoning,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Stop any `json:"stop,omitempty"`
- N int `json:"n,omitempty"`
- Input any `json:"input,omitempty"`
- Instruction string `json:"instruction,omitempty"`
- Size string `json:"size,omitempty"`
- Functions any `json:"functions,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
- EncodingFormat any `json:"encoding_format,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- Tools []ToolCallRequest `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- User string `json:"user,omitempty"`
- LogProbs bool `json:"logprobs,omitempty"`
- TopLogProbs int `json:"top_logprobs,omitempty"`
- Dimensions int `json:"dimensions,omitempty"`
- Modalities any `json:"modalities,omitempty"`
- Audio any `json:"audio,omitempty"`
- EnableThinking any `json:"enable_thinking,omitempty"` // ali
- ExtraBody any `json:"extra_body,omitempty"`
+ Model string `json:"model,omitempty"`
+ Messages []Message `json:"messages,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+ Prefix any `json:"prefix,omitempty"`
+ Suffix any `json:"suffix,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ StreamOptions *StreamOptions `json:"stream_options,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
+ ReasoningEffort string `json:"reasoning_effort,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Stop any `json:"stop,omitempty"`
+ N int `json:"n,omitempty"`
+ Input any `json:"input,omitempty"`
+ Instruction string `json:"instruction,omitempty"`
+ Size string `json:"size,omitempty"`
+ Functions json.RawMessage `json:"functions,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ EncodingFormat json.RawMessage `json:"encoding_format,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
+ Tools []ToolCallRequest `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ User string `json:"user,omitempty"`
+ LogProbs bool `json:"logprobs,omitempty"`
+ TopLogProbs int `json:"top_logprobs,omitempty"`
+ Dimensions int `json:"dimensions,omitempty"`
+ Modalities json.RawMessage `json:"modalities,omitempty"`
+ Audio json.RawMessage `json:"audio,omitempty"`
+ EnableThinking any `json:"enable_thinking,omitempty"` // ali
+ THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
+ ExtraBody json.RawMessage `json:"extra_body,omitempty"`
+ WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
+ // OpenRouter Params
+ Reasoning json.RawMessage `json:"reasoning,omitempty"`
+ // Ali Qwen Params
+ VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
+}
+
+func (r *GeneralOpenAIRequest) ToMap() map[string]any {
+ result := make(map[string]any)
+ data, _ := common.EncodeJson(r)
+ _ = common.DecodeJson(data, &result)
+ return result
}
type ToolCallRequest struct {
@@ -72,11 +86,11 @@ type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
-func (r GeneralOpenAIRequest) GetMaxTokens() int {
+func (r *GeneralOpenAIRequest) GetMaxTokens() int {
return int(r.MaxTokens)
}
-func (r GeneralOpenAIRequest) ParseInput() []string {
+func (r *GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
@@ -96,16 +110,16 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
}
type Message struct {
- Role string `json:"role"`
- Content json.RawMessage `json:"content"`
- Name *string `json:"name,omitempty"`
- Prefix *bool `json:"prefix,omitempty"`
- ReasoningContent string `json:"reasoning_content,omitempty"`
- Reasoning string `json:"reasoning,omitempty"`
- ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
- ToolCallId string `json:"tool_call_id,omitempty"`
- parsedContent []MediaContent
- parsedStringContent *string
+ Role string `json:"role"`
+ Content any `json:"content"`
+ Name *string `json:"name,omitempty"`
+ Prefix *bool `json:"prefix,omitempty"`
+ ReasoningContent string `json:"reasoning_content,omitempty"`
+ Reasoning string `json:"reasoning,omitempty"`
+ ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
+ ToolCallId string `json:"tool_call_id,omitempty"`
+ parsedContent []MediaContent
+ //parsedStringContent *string
}
type MediaContent struct {
@@ -115,25 +129,56 @@ type MediaContent struct {
InputAudio any `json:"input_audio,omitempty"`
File any `json:"file,omitempty"`
VideoUrl any `json:"video_url,omitempty"`
+ // OpenRouter Params
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
if m.ImageUrl != nil {
- return m.ImageUrl.(*MessageImageUrl)
+ if _, ok := m.ImageUrl.(*MessageImageUrl); ok {
+ return m.ImageUrl.(*MessageImageUrl)
+ }
+ if itemMap, ok := m.ImageUrl.(map[string]any); ok {
+ out := &MessageImageUrl{
+ Url: common.Interface2String(itemMap["url"]),
+ Detail: common.Interface2String(itemMap["detail"]),
+ MimeType: common.Interface2String(itemMap["mime_type"]),
+ }
+ return out
+ }
}
return nil
}
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
if m.InputAudio != nil {
- return m.InputAudio.(*MessageInputAudio)
+ if _, ok := m.InputAudio.(*MessageInputAudio); ok {
+ return m.InputAudio.(*MessageInputAudio)
+ }
+ if itemMap, ok := m.InputAudio.(map[string]any); ok {
+ out := &MessageInputAudio{
+ Data: common.Interface2String(itemMap["data"]),
+ Format: common.Interface2String(itemMap["format"]),
+ }
+ return out
+ }
}
return nil
}
func (m *MediaContent) GetFile() *MessageFile {
if m.File != nil {
- return m.File.(*MessageFile)
+ if _, ok := m.File.(*MessageFile); ok {
+ return m.File.(*MessageFile)
+ }
+ if itemMap, ok := m.File.(map[string]any); ok {
+ out := &MessageFile{
+ FileName: common.Interface2String(itemMap["file_name"]),
+ FileData: common.Interface2String(itemMap["file_data"]),
+ FileId: common.Interface2String(itemMap["file_id"]),
+ }
+ return out
+ }
}
return nil
}
@@ -199,6 +244,186 @@ func (m *Message) SetToolCalls(toolCalls any) {
}
func (m *Message) StringContent() string {
+ switch m.Content.(type) {
+ case string:
+ return m.Content.(string)
+ case []any:
+ var contentStr string
+ for _, contentItem := range m.Content.([]any) {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+
+ return ""
+}
+
+func (m *Message) SetNullContent() {
+ m.Content = nil
+ m.parsedContent = nil
+}
+
+func (m *Message) SetStringContent(content string) {
+ m.Content = content
+ m.parsedContent = nil
+}
+
+func (m *Message) SetMediaContent(content []MediaContent) {
+ m.Content = content
+ m.parsedContent = content
+}
+
+func (m *Message) IsStringContent() bool {
+ _, ok := m.Content.(string)
+ if ok {
+ return true
+ }
+ return false
+}
+
+func (m *Message) ParseContent() []MediaContent {
+ if m.Content == nil {
+ return nil
+ }
+ if len(m.parsedContent) > 0 {
+ return m.parsedContent
+ }
+
+ var contentList []MediaContent
+ // 先尝试解析为字符串
+ content, ok := m.Content.(string)
+ if ok {
+ contentList = []MediaContent{{
+ Type: ContentTypeText,
+ Text: content,
+ }}
+ m.parsedContent = contentList
+ return contentList
+ }
+
+ // 尝试解析为数组
+ //var arrayContent []map[string]interface{}
+
+ arrayContent, ok := m.Content.([]any)
+ if !ok {
+ return contentList
+ }
+
+ for _, contentItemAny := range arrayContent {
+ mediaItem, ok := contentItemAny.(MediaContent)
+ if ok {
+ contentList = append(contentList, mediaItem)
+ continue
+ }
+
+ contentItem, ok := contentItemAny.(map[string]any)
+ if !ok {
+ continue
+ }
+ contentType, ok := contentItem["type"].(string)
+ if !ok {
+ continue
+ }
+
+ switch contentType {
+ case ContentTypeText:
+ if text, ok := contentItem["text"].(string); ok {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeText,
+ Text: text,
+ })
+ }
+
+ case ContentTypeImageURL:
+ imageUrl := contentItem["image_url"]
+ temp := &MessageImageUrl{
+ Detail: "high",
+ }
+ switch v := imageUrl.(type) {
+ case string:
+ temp.Url = v
+ case map[string]interface{}:
+ url, ok1 := v["url"].(string)
+ detail, ok2 := v["detail"].(string)
+ if ok2 {
+ temp.Detail = detail
+ }
+ if ok1 {
+ temp.Url = url
+ }
+ }
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeImageURL,
+ ImageUrl: temp,
+ })
+
+ case ContentTypeInputAudio:
+ if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
+ data, ok1 := audioData["data"].(string)
+ format, ok2 := audioData["format"].(string)
+ if ok1 && ok2 {
+ temp := &MessageInputAudio{
+ Data: data,
+ Format: format,
+ }
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeInputAudio,
+ InputAudio: temp,
+ })
+ }
+ }
+ case ContentTypeFile:
+ if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
+ fileId, ok3 := fileData["file_id"].(string)
+ if ok3 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeFile,
+ File: &MessageFile{
+ FileId: fileId,
+ },
+ })
+ } else {
+ fileName, ok1 := fileData["filename"].(string)
+ fileDataStr, ok2 := fileData["file_data"].(string)
+ if ok1 && ok2 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeFile,
+ File: &MessageFile{
+ FileName: fileName,
+ FileData: fileDataStr,
+ },
+ })
+ }
+ }
+ }
+ case ContentTypeVideoUrl:
+ if videoUrl, ok := contentItem["video_url"].(string); ok {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeVideoUrl,
+ VideoUrl: &MessageVideoUrl{
+ Url: videoUrl,
+ },
+ })
+ }
+ }
+ }
+
+ if len(contentList) > 0 {
+ m.parsedContent = contentList
+ }
+ return contentList
+}
+
+// old code
+/*func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
@@ -369,6 +594,11 @@ func (m *Message) ParseContent() []MediaContent {
m.parsedContent = contentList
}
return contentList
+}*/
+
+type WebSearchOptions struct {
+ SearchContextSize string `json:"search_context_size,omitempty"`
+ UserLocation json.RawMessage `json:"user_location,omitempty"`
}
type OpenAIResponsesRequest struct {
diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go
new file mode 100644
index 00000000..55a89025
--- /dev/null
+++ b/dto/ratio_sync.go
@@ -0,0 +1,49 @@
+package dto
+
+// UpstreamDTO 提交到后端同步倍率的上游渠道信息
+// Endpoint 可以为空,后端会默认使用 /api/ratio_config
+// BaseURL 必须以 http/https 开头,不要以 / 结尾
+// 例如: https://api.example.com
+// Endpoint: /api/ratio_config
+// 提交示例:
+// {
+// "name": "openai",
+// "base_url": "https://api.openai.com",
+// "endpoint": "/ratio_config"
+// }
+
+type UpstreamDTO struct {
+ Name string `json:"name" binding:"required"`
+ BaseURL string `json:"base_url" binding:"required"`
+ Endpoint string `json:"endpoint"`
+}
+
+type UpstreamRequest struct {
+ ChannelIDs []int64 `json:"channel_ids"`
+ Timeout int `json:"timeout"`
+}
+
+// TestResult 上游测试连通性结果
+type TestResult struct {
+ Name string `json:"name"`
+ Status string `json:"status"`
+ Error string `json:"error,omitempty"`
+}
+
+// DifferenceItem 差异项
+// Current 为本地值,可能为 nil
+// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
+
+type DifferenceItem struct {
+ Current interface{} `json:"current"`
+ Upstreams map[string]interface{} `json:"upstreams"`
+}
+
+// SyncableChannel 可同步的渠道信息(base_url 不为空)
+
+type SyncableChannel struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ BaseURL string `json:"base_url"`
+ Status int `json:"status"`
+}
\ No newline at end of file
diff --git a/dto/video.go b/dto/video.go
new file mode 100644
index 00000000..5b48146a
--- /dev/null
+++ b/dto/video.go
@@ -0,0 +1,47 @@
+package dto
+
+type VideoRequest struct {
+ Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
+ Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
+ Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
+ Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
+ Width int `json:"width" example:"512"` // Video width
+ Height int `json:"height" example:"512"` // Video height
+ Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
+ Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
+ N int `json:"n,omitempty" example:"1"` // Number of videos to generate
+ ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
+ User string `json:"user,omitempty" example:"user-1234"` // User identifier
+ Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
+}
+
+// VideoResponse 视频生成提交任务后的响应
+type VideoResponse struct {
+ TaskId string `json:"task_id"`
+ Status string `json:"status"`
+}
+
+// VideoTaskResponse 查询视频生成任务状态的响应
+type VideoTaskResponse struct {
+ TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
+ Status string `json:"status" example:"succeeded"` // 任务状态
+ Url string `json:"url,omitempty"` // 视频资源URL(成功时)
+ Format string `json:"format,omitempty" example:"mp4"` // 视频格式
+ Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
+ Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
+}
+
+// VideoTaskMetadata 视频任务元数据
+type VideoTaskMetadata struct {
+ Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
+ Fps int `json:"fps" example:"30"` // 实际帧率
+ Width int `json:"width" example:"512"` // 实际宽度
+ Height int `json:"height" example:"512"` // 实际高度
+ Seed int `json:"seed" example:"20231234"` // 使用的随机种子
+}
+
+// VideoTaskError 视频任务错误信息
+type VideoTaskError struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+}
diff --git a/go.mod b/go.mod
index ce768bf3..9479ba55 100644
--- a/go.mod
+++ b/go.mod
@@ -11,7 +11,6 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
- github.com/bytedance/sonic v1.11.6
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@@ -25,10 +24,10 @@ require (
github.com/gorilla/websocket v1.5.0
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
- github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
+ github.com/tiktoken-go/tokenizer v0.6.2
golang.org/x/crypto v0.35.0
golang.org/x/image v0.23.0
golang.org/x/net v0.35.0
@@ -43,12 +42,13 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
+ github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
- github.com/dlclark/regexp2 v1.11.0 // indirect
+ github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
diff --git a/go.sum b/go.sum
index 2bd81fa3..71dd83c2 100644
--- a/go.sum
+++ b/go.sum
@@ -38,8 +38,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
-github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
-github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
+github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
@@ -167,8 +167,6 @@ github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
-github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
@@ -197,6 +195,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
+github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
diff --git a/main.go b/main.go
index 95c6820d..cf593b57 100644
--- a/main.go
+++ b/main.go
@@ -12,7 +12,7 @@ import (
"one-api/model"
"one-api/router"
"one-api/service"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
"os"
"strconv"
@@ -74,7 +74,7 @@ func main() {
}
// Initialize model settings
- operation_setting.InitRatioSettings()
+ ratio_setting.InitRatioSettings()
// Initialize constants
constant.InitEnv()
// Initialize options
@@ -89,13 +89,28 @@ func main() {
if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
- model.InitChannelCache()
- }
- if common.MemoryCacheEnabled {
- go model.SyncOptions(common.SyncFrequency)
+
+ // Add panic recovery and retry for InitChannelCache
+ func() {
+ defer func() {
+ if r := recover(); r != nil {
+ common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
+ // Retry once
+ _, fixErr := model.FixAbility()
+ if fixErr != nil {
+ common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
+ }
+ }
+ }()
+ model.InitChannelCache()
+ }()
+
go model.SyncChannelCache(common.SyncFrequency)
}
+ // 热更新配置
+ go model.SyncOptions(common.SyncFrequency)
+
// 数据看板
go model.UpdateQuotaData()
diff --git a/makefile b/makefile
index 5042723c..cbc4ea6a 100644
--- a/makefile
+++ b/makefile
@@ -7,7 +7,7 @@ all: build-frontend start-backend
build-frontend:
@echo "Building frontend..."
- @cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
+ @cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
start-backend:
@echo "Starting backend dev server..."
diff --git a/middleware/auth.go b/middleware/auth.go
index fece4553..f387029f 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -1,13 +1,14 @@
package middleware
import (
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
)
func validUserInfo(username string, role int) bool {
@@ -182,6 +183,18 @@ func TokenAuth() func(c *gin.Context) {
c.Request.Header.Set("Authorization", "Bearer "+key)
}
}
+ // gemini api 从query中获取key
+ if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+ skKey := c.Query("key")
+ if skKey != "" {
+ c.Request.Header.Set("Authorization", "Bearer "+skKey)
+ }
+ // 从x-goog-api-key header中获取key
+ xGoogKey := c.Request.Header.Get("x-goog-api-key")
+ if xGoogKey != "" {
+ c.Request.Header.Set("Authorization", "Bearer "+xGoogKey)
+ }
+ }
key := c.Request.Header.Get("Authorization")
parts := make([]string, 0)
key = strings.TrimPrefix(key, "Bearer ")
diff --git a/middleware/distributor.go b/middleware/distributor.go
index e7db6d77..9d074ce8 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -11,6 +11,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
+ "one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -48,9 +49,11 @@ func Distribute() func(c *gin.Context) {
return
}
// check group in common.GroupRatio
- if !setting.ContainsGroupRatio(tokenGroup) {
- abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
- return
+ if !ratio_setting.ContainsGroupRatio(tokenGroup) {
+ if tokenGroup != "auto" {
+ abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
+ return
+ }
}
userGroup = tokenGroup
}
@@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) {
}
if shouldSelectChannel {
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
+ var selectGroup string
+ channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+ showGroup := userGroup
+ if userGroup == "auto" {
+ showGroup = fmt.Sprintf("auto(%s)", selectGroup)
+ }
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
@@ -162,6 +170,23 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
+ } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
+ relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeKlingFetchByID {
+ shouldSelectChannel = false
+ } else {
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
+ }
+ c.Set("platform", string(constant.TaskPlatformKling))
+ c.Set("relay_mode", relayMode)
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+ // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
+ relayMode := relayconstant.RelayModeGemini
+ modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
+ if modelName != "" {
+ modelRequest.Model = modelName
+ }
+ c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
@@ -244,3 +269,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("bot_id", channel.Other)
}
}
+
+// extractModelNameFromGeminiPath extracts the model name from a Gemini API URL path.
+// Input format: /v1beta/models/gemini-2.0-flash:generateContent
+// Output: gemini-2.0-flash
+func extractModelNameFromGeminiPath(path string) string {
+	// Locate the "/models/" segment.
+	modelsPrefix := "/models/"
+	modelsIndex := strings.Index(path, modelsPrefix)
+	if modelsIndex == -1 {
+		return ""
+	}
+
+	// Start extracting right after "/models/".
+	startIndex := modelsIndex + len(modelsPrefix)
+	if startIndex >= len(path) {
+		return ""
+	}
+
+	// The model name ends at the first ":" (the action separator).
+	colonIndex := strings.Index(path[startIndex:], ":")
+	if colonIndex == -1 {
+		// No ":" found — the rest of the path is the model name.
+		return path[startIndex:]
+	}
+
+	// Return the model-name portion.
+	return path[startIndex : startIndex+colonIndex]
+}
diff --git a/middleware/stats.go b/middleware/stats.go
new file mode 100644
index 00000000..1c97983f
--- /dev/null
+++ b/middleware/stats.go
@@ -0,0 +1,41 @@
+package middleware
+
+import (
+	"sync/atomic"
+
+	"github.com/gin-gonic/gin"
+)
+
+// HTTPStats holds process-wide HTTP statistics.
+type HTTPStats struct {
+	// activeConnections is the number of requests currently in flight;
+	// accessed only via sync/atomic.
+	activeConnections int64
+}
+
+var globalStats = &HTTPStats{}
+
+// StatsMiddleware returns a gin middleware that tracks the number of
+// in-flight requests in globalStats.
+func StatsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		// Count this request as active.
+		atomic.AddInt64(&globalStats.activeConnections, 1)
+
+		// Decrement when the request finishes, even if a handler panics.
+		defer func() {
+			atomic.AddInt64(&globalStats.activeConnections, -1)
+		}()
+
+		c.Next()
+	}
+}
+
+// StatsInfo is the JSON-serializable snapshot of the statistics.
+type StatsInfo struct {
+	ActiveConnections int64 `json:"active_connections"`
+}
+
+// GetStats returns a snapshot of the current statistics.
+func GetStats() StatsInfo {
+	return StatsInfo{
+		ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
+	}
+}
diff --git a/model/ability.go b/model/ability.go
index 52720307..96a9ef6a 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -8,6 +8,7 @@ import (
"github.com/samber/lo"
"gorm.io/gorm"
+ "gorm.io/gorm/clause"
)
type Ability struct {
@@ -23,7 +24,7 @@ type Ability struct {
func GetGroupModels(group string) []string {
var models []string
// Find distinct models
- DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
+ DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
return models
}
@@ -41,15 +42,11 @@ func GetAllEnableAbilities() []Ability {
}
func getPriority(group string, model string, retry int) (int, error) {
- trueVal := "1"
- if common.UsingPostgreSQL {
- trueVal = "true"
- }
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
- Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
+ Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
@@ -75,18 +72,14 @@ func getPriority(group string, model string, retry int) (int, error) {
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
- trueVal := "1"
- if common.UsingPostgreSQL {
- trueVal = "true"
- }
- maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
- channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+ maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
+ channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
- channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
+ channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
}
}
@@ -133,9 +126,15 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
+ abilitySet := make(map[string]struct{})
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
+ key := group + "|" + model
+ if _, exists := abilitySet[key]; exists {
+ continue
+ }
+ abilitySet[key] = struct{}{}
ability := Ability{
Group: group,
Model: model,
@@ -152,7 +151,7 @@ func (channel *Channel) AddAbilities() error {
return nil
}
for _, chunk := range lo.Chunk(abilities, 50) {
- err := DB.Create(&chunk).Error
+ err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
if err != nil {
return err
}
@@ -194,9 +193,15 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
// Then add new abilities
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
+ abilitySet := make(map[string]struct{})
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
+ key := group + "|" + model
+ if _, exists := abilitySet[key]; exists {
+ continue
+ }
+ abilitySet[key] = struct{}{}
ability := Ability{
Group: group,
Model: model,
@@ -212,7 +217,7 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
if len(abilities) > 0 {
for _, chunk := range lo.Chunk(abilities, 50) {
- err = tx.Create(&chunk).Error
+ err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
if err != nil {
if isNewTx {
tx.Rollback()
@@ -261,12 +266,28 @@ func FixAbility() (int, error) {
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
return 0, err
}
- // Delete abilities of channels that are not in channel table
- err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
- if err != nil {
- common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
- return 0, err
+
+	// Delete abilities of channels that are no longer in the channels table.
+	// NOTE: this must stay a single statement — chunking a `NOT IN` condition
+	// is incorrect, because each chunk's DELETE would also remove abilities
+	// whose channel merely sits in a *different* chunk. A subquery sidesteps
+	// the "too many placeholders" problem entirely and naturally handles an
+	// empty channels table (it then deletes every ability, exactly as the
+	// previous explicit empty-table branch did).
+	err = DB.Where("channel_id NOT IN (?)", DB.Model(&Channel{}).Select("id")).Delete(&Ability{}).Error
+	if err != nil {
+		common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
+		return 0, err
}
+
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
count += len(channelIds)
@@ -275,17 +296,26 @@ func FixAbility() (int, error) {
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
- return 0, err
+ return count, err
}
+
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
- err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
- }
- if err != nil {
- return 0, err
+		// Chunking `id NOT IN (chunk)` is incorrect — each chunk's query also
+		// returns channels whose id merely sits in another chunk, so channels
+		// that DO already have abilities would be picked up (and repeatedly
+		// appended). A subquery needs no SQL placeholders at all.
+		err = DB.Where("id NOT IN (?)", DB.Table("abilities").Distinct("channel_id")).Find(&channels).Error
+		if err != nil {
+			common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
+			return count, err
+		}
}
+
for _, channel := range channels {
err := channel.UpdateAbilities(nil)
if err != nil {
diff --git a/model/cache.go b/model/cache.go
index 2d1c36bf..3e5eb4c4 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -5,10 +5,13 @@ import (
"fmt"
"math/rand"
"one-api/common"
+ "one-api/setting"
"sort"
"strings"
"sync"
"time"
+
+ "github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]*Channel
@@ -16,6 +19,9 @@ var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex
func InitChannelCache() {
+ if !common.MemoryCacheEnabled {
+ return
+ }
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
@@ -72,7 +78,43 @@ func SyncChannelCache(frequency int) {
}
}
-func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
+	var channel *Channel
+	var err error
+	// selectGroup is the group that actually served the request; for "auto"
+	// it becomes the concrete group that yielded a channel.
+	selectGroup := group
+	if group == "auto" {
+		if len(setting.AutoGroups) == 0 {
+			return nil, selectGroup, errors.New("auto groups is not enabled")
+		}
+		// Try the configured auto groups in order; the first group with an
+		// available channel wins. Per-group lookup errors are deliberately
+		// ignored — a failed group simply falls through to the next one.
+		for _, autoGroup := range setting.AutoGroups {
+			if common.DebugEnabled {
+				println("autoGroup:", autoGroup)
+			}
+			channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
+			if channel == nil {
+				continue
+			} else {
+				// Record the resolved group so later middleware can see
+				// which group was actually used.
+				c.Set("auto_group", autoGroup)
+				selectGroup = autoGroup
+				if common.DebugEnabled {
+					println("selectGroup:", selectGroup)
+				}
+				break
+			}
+		}
+	} else {
+		channel, err = getRandomSatisfiedChannel(group, model, retry)
+		if err != nil {
+			return nil, group, err
+		}
+	}
+	if channel == nil {
+		// Either no auto group produced a channel, or the direct lookup
+		// returned nil without an error.
+		return nil, group, errors.New("channel not found")
+	}
+	return channel, selectGroup, nil
+}
+
+func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}
@@ -84,11 +126,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model, retry)
}
-
+
channelSyncLock.RLock()
channels := group2model2channels[group][model]
channelSyncLock.RUnlock()
-
+
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
diff --git a/model/channel.go b/model/channel.go
index 41e5e371..6cbd8adc 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -46,6 +46,17 @@ func (channel *Channel) GetModels() []string {
return strings.Split(strings.Trim(channel.Models, ","), ",")
}
+// GetGroups splits the channel's comma-separated Group field into a slice
+// of trimmed group names; an empty field yields an empty (non-nil) slice.
+func (channel *Channel) GetGroups() []string {
+	if channel.Group == "" {
+		return []string{}
+	}
+	groups := strings.Split(strings.Trim(channel.Group, ","), ",")
+	for i, group := range groups {
+		// Tolerate spaces around the commas, e.g. "default, vip".
+		groups[i] = strings.TrimSpace(group)
+	}
+	return groups
+}
+
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {
@@ -134,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
}
// 构造基础查询
- baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+ baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
var whereClause string
@@ -142,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
- groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+ groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
- groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+ groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
@@ -467,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
}
// 构造基础查询
- baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+ baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
var whereClause string
@@ -475,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
- groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+ groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
- groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+ groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
@@ -572,3 +583,53 @@ func BatchSetChannelTag(ids []int, tag *string) error {
// 提交事务
return tx.Commit().Error
}
+
+// CountAllChannels returns total channels in DB
+func CountAllChannels() (int64, error) {
+ var total int64
+ err := DB.Model(&Channel{}).Count(&total).Error
+ return total, err
+}
+
+// CountAllTags returns number of non-empty distinct tags
+func CountAllTags() (int64, error) {
+ var total int64
+ err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
+ return total, err
+}
+
+// Get channels of specified type with pagination
+func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
+ var channels []*Channel
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+ err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
+ return channels, err
+}
+
+// Count channels of specific type
+func CountChannelsByType(channelType int) (int64, error) {
+ var count int64
+ err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
+ return count, err
+}
+
+// Return map[type]count for all channels
+func CountChannelsGroupByType() (map[int64]int64, error) {
+ type result struct {
+ Type int64 `gorm:"column:type"`
+ Count int64 `gorm:"column:count"`
+ }
+ var results []result
+ err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
+ if err != nil {
+ return nil, err
+ }
+ counts := make(map[int64]int64)
+ for _, r := range results {
+ counts[r.Type] = r.Count
+ }
+ return counts, nil
+}
diff --git a/model/log.go b/model/log.go
index 0a891fcd..b3fd1ad2 100644
--- a/model/log.go
+++ b/model/log.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"one-api/common"
+ "one-api/constant"
"os"
"strings"
"time"
@@ -32,6 +33,7 @@ type Log struct {
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
+ Ip string `json:"ip" gorm:"index;default:''"`
Other string `json:"other"`
}
@@ -61,7 +63,7 @@ func formatUserLogs(logs []*Log) {
func GetLogByKey(key string) (logs []*Log, err error) {
if os.Getenv("LOG_SQL_DSN") != "" {
var tk Token
- if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
+ if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err
}
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
@@ -95,6 +97,15 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
+ // 判断是否需要记录 IP
+ needRecordIp := false
+ if settingMap, err := GetUserSetting(userId, false); err == nil {
+ if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
+ if vb, ok := v.(bool); ok && vb {
+ needRecordIp = true
+ }
+ }
+ }
log := &Log{
UserId: userId,
Username: username,
@@ -111,7 +122,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
- Other: otherStr,
+ Ip: func() string {
+ if needRecordIp {
+ return c.ClientIP()
+ }
+ return ""
+ }(),
+ Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -128,6 +145,15 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
}
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
+ // 判断是否需要记录 IP
+ needRecordIp := false
+ if settingMap, err := GetUserSetting(userId, false); err == nil {
+ if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
+ if vb, ok := v.(bool); ok && vb {
+ needRecordIp = true
+ }
+ }
+ }
log := &Log{
UserId: userId,
Username: username,
@@ -144,7 +170,13 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
- Other: otherStr,
+ Ip: func() string {
+ if needRecordIp {
+ return c.ClientIP()
+ }
+ return ""
+ }(),
+ Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -184,7 +216,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = tx.Where("logs.channel_id = ?", channel)
}
if group != "" {
- tx = tx.Where("logs."+groupCol+" = ?", group)
+ tx = tx.Where("logs."+logGroupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
@@ -195,13 +227,18 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
return nil, 0, err
}
- channelIds := make([]int, 0)
+ channelIdsMap := make(map[int]struct{})
channelMap := make(map[int]string)
for _, log := range logs {
if log.ChannelId != 0 {
- channelIds = append(channelIds, log.ChannelId)
+ channelIdsMap[log.ChannelId] = struct{}{}
}
}
+
+ channelIds := make([]int, 0, len(channelIdsMap))
+ for channelId := range channelIdsMap {
+ channelIds = append(channelIds, channelId)
+ }
if len(channelIds) > 0 {
var channels []struct {
Id int `gorm:"column:id"`
@@ -242,7 +279,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = tx.Where("logs.created_at <= ?", endTimestamp)
}
if group != "" {
- tx = tx.Where("logs."+groupCol+" = ?", group)
+ tx = tx.Where("logs."+logGroupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
@@ -303,8 +340,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
}
if group != "" {
- tx = tx.Where(groupCol+" = ?", group)
- rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
+ tx = tx.Where(logGroupCol+" = ?", group)
+ rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
}
tx = tx.Where("type = ?", LogTypeConsume)
diff --git a/model/main.go b/model/main.go
index 61d6bb10..d46a21cf 100644
--- a/model/main.go
+++ b/model/main.go
@@ -1,6 +1,7 @@
package model
import (
+ "fmt"
"log"
"one-api/common"
"one-api/constant"
@@ -15,18 +16,48 @@ import (
"gorm.io/gorm"
)
-var groupCol string
-var keyCol string
+var commonGroupCol string
+var commonKeyCol string
+var commonTrueVal string
+var commonFalseVal string
+
+var logKeyCol string
+var logGroupCol string
func initCol() {
+ // init common column names
if common.UsingPostgreSQL {
- groupCol = `"group"`
- keyCol = `"key"`
-
+ commonGroupCol = `"group"`
+ commonKeyCol = `"key"`
+ commonTrueVal = "true"
+ commonFalseVal = "false"
} else {
- groupCol = "`group`"
- keyCol = "`key`"
+ commonGroupCol = "`group`"
+ commonKeyCol = "`key`"
+ commonTrueVal = "1"
+ commonFalseVal = "0"
}
+	if os.Getenv("LOG_SQL_DSN") != "" {
+		// Quote style must follow the LOG database's dialect. Do not fall
+		// back to commonGroupCol here: that reflects the MAIN database
+		// (SQL_DSN), which may use a different dialect than the log database.
+		switch common.LogSqlType {
+		case common.DatabaseTypePostgreSQL:
+			logGroupCol = `"group"`
+			logKeyCol = `"key"`
+		default:
+			// MySQL and SQLite both quote identifiers with backticks.
+			logGroupCol = "`group`"
+			logKeyCol = "`key`"
+		}
+		// Only meaningful when a separate log database is configured;
+		// LogSqlType is never set otherwise.
+		common.SysLog("Using Log SQL Type: " + common.LogSqlType)
+	} else {
+		// LOG_SQL_DSN is empty: logs share the main database and therefore
+		// its column quoting.
+		logGroupCol = commonGroupCol
+		logKeyCol = commonKeyCol
+	}
}
var DB *gorm.DB
@@ -83,7 +114,7 @@ func CheckSetup() {
}
}
-func chooseDB(envName string) (*gorm.DB, error) {
+func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
defer func() {
initCol()
}()
@@ -92,7 +123,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
- common.UsingPostgreSQL = true
+ if !isLog {
+ common.UsingPostgreSQL = true
+ } else {
+ common.LogSqlType = common.DatabaseTypePostgreSQL
+ }
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
@@ -102,7 +137,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
}
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
- common.UsingSQLite = true
+ if !isLog {
+ common.UsingSQLite = true
+ } else {
+ common.LogSqlType = common.DatabaseTypeSQLite
+ }
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -117,7 +156,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
dsn += "?parseTime=true"
}
}
- common.UsingMySQL = true
+ if !isLog {
+ common.UsingMySQL = true
+ } else {
+ common.LogSqlType = common.DatabaseTypeMySQL
+ }
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -131,7 +174,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
}
func InitDB() (err error) {
- db, err := chooseDB("SQL_DSN")
+ db, err := chooseDB("SQL_DSN", false)
if err == nil {
if common.DebugEnabled {
db = db.Debug()
@@ -149,7 +192,7 @@ func InitDB() (err error) {
return nil
}
if common.UsingMySQL {
- _, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
+ //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
}
common.SysLog("database migration started")
err = migrateDB()
@@ -165,7 +208,7 @@ func InitLogDB() (err error) {
LOG_DB = DB
return
}
- db, err := chooseDB("LOG_SQL_DSN")
+ db, err := chooseDB("LOG_SQL_DSN", true)
if err == nil {
if common.DebugEnabled {
db = db.Debug()
@@ -198,54 +241,73 @@ func InitLogDB() (err error) {
}
func migrateDB() error {
- err := DB.AutoMigrate(&Channel{})
+ if !common.UsingPostgreSQL {
+ return migrateDBFast()
+ }
+ err := DB.AutoMigrate(
+ &Channel{},
+ &Token{},
+ &User{},
+ &Option{},
+ &Redemption{},
+ &Ability{},
+ &Log{},
+ &Midjourney{},
+ &TopUp{},
+ &QuotaData{},
+ &Task{},
+ &Setup{},
+ )
if err != nil {
return err
}
- err = DB.AutoMigrate(&Token{})
- if err != nil {
- return err
+ return nil
+}
+
+// migrateDBFast runs the per-model AutoMigrate calls concurrently.
+// NOTE(review): this path is only taken when the main database is NOT
+// PostgreSQL (see migrateDB), i.e. SQLite or MySQL. Concurrent DDL is risky
+// there — SQLite is single-writer ("database is locked" errors) and MySQL
+// can hit metadata locks — confirm this is safe for those dialects.
+func migrateDBFast() error {
+	var wg sync.WaitGroup
+	errChan := make(chan error, 12) // Buffer size matches number of migrations
+
+ migrations := []struct {
+ model interface{}
+ name string
+ }{
+ {&Channel{}, "Channel"},
+ {&Token{}, "Token"},
+ {&User{}, "User"},
+ {&Option{}, "Option"},
+ {&Redemption{}, "Redemption"},
+ {&Ability{}, "Ability"},
+ {&Log{}, "Log"},
+ {&Midjourney{}, "Midjourney"},
+ {&TopUp{}, "TopUp"},
+ {&QuotaData{}, "QuotaData"},
+ {&Task{}, "Task"},
+ {&Setup{}, "Setup"},
}
- err = DB.AutoMigrate(&User{})
- if err != nil {
- return err
+
+ for _, m := range migrations {
+ wg.Add(1)
+ go func(model interface{}, name string) {
+ defer wg.Done()
+ if err := DB.AutoMigrate(model); err != nil {
+ errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
+ }
+ }(m.model, m.name)
}
- err = DB.AutoMigrate(&Option{})
- if err != nil {
- return err
+
+ // Wait for all migrations to complete
+ wg.Wait()
+ close(errChan)
+
+ // Check for any errors
+ for err := range errChan {
+ if err != nil {
+ return err
+ }
}
- err = DB.AutoMigrate(&Redemption{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Ability{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Log{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Midjourney{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&TopUp{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&QuotaData{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Task{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated")
- //err = createRootAccountIfNeed()
- return err
+ return nil
}
func migrateLOGDB() error {
diff --git a/model/midjourney.go b/model/midjourney.go
index 5f85abfd..e8140447 100644
--- a/model/midjourney.go
+++ b/model/midjourney.go
@@ -166,3 +166,40 @@ func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
Where("id in (?)", taskIDs).
Updates(params).Error
}
+
+// CountAllTasks returns total midjourney tasks for admin query
+func CountAllTasks(queryParams TaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Midjourney{})
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.MjID != "" {
+ query = query.Where("mj_id = ?", queryParams.MjID)
+ }
+ if queryParams.StartTimestamp != "" {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != "" {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
+
+// CountAllUserTask returns total midjourney tasks for user
+func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Midjourney{}).Where("user_id = ?", userId)
+ if queryParams.MjID != "" {
+ query = query.Where("mj_id = ?", queryParams.MjID)
+ }
+ if queryParams.StartTimestamp != "" {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != "" {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
diff --git a/model/option.go b/model/option.go
index d892b120..ea72e5ee 100644
--- a/model/option.go
+++ b/model/option.go
@@ -5,6 +5,7 @@ import (
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -76,6 +77,9 @@ func InitOptionMap() {
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
+ common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
+ common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
+ common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -94,12 +98,13 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
- common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
- common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
- common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
- common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
+ common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
+ common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
+ common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
+ common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
+ common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
- common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
+ common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -122,6 +127,7 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
+ common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -191,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
- if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
+ if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@@ -260,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
+ case "DefaultUseAutoGroup":
+ setting.DefaultUseAutoGroup = boolValue
+ case "ExposeRatioEnabled":
+ ratio_setting.SetExposeRatioEnabled(boolValue)
}
}
switch key {
@@ -286,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
+ case "AutoGroups":
+ err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
@@ -351,17 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
- err = operation_setting.UpdateModelRatioByJSONString(value)
+ err = ratio_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
- err = setting.UpdateGroupRatioByJSONString(value)
+ err = ratio_setting.UpdateGroupRatioByJSONString(value)
+ case "GroupGroupRatio":
+ err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
- err = operation_setting.UpdateCompletionRatioByJSONString(value)
+ err = ratio_setting.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
- err = operation_setting.UpdateModelPriceByJSONString(value)
+ err = ratio_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
- err = operation_setting.UpdateCacheRatioByJSONString(value)
+ err = ratio_setting.UpdateCacheRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
@@ -378,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
+ case "PayMethods":
+ err = setting.UpdatePayMethodsByJsonString(value)
}
return err
}
diff --git a/model/pricing.go b/model/pricing.go
index ba1815e2..74a25f2d 100644
--- a/model/pricing.go
+++ b/model/pricing.go
@@ -2,7 +2,7 @@ package model
import (
"one-api/common"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
"sync"
"time"
)
@@ -65,14 +65,14 @@ func updatePricing() {
ModelName: model,
EnableGroup: groups,
}
- modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
+ modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
- modelRatio, _ := operation_setting.GetModelRatio(model)
+ modelRatio, _ := ratio_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio
- pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
+ pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)
diff --git a/model/redemption.go b/model/redemption.go
index 89c4ac8c..bf237668 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -21,6 +21,7 @@ type Redemption struct {
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
DeletedAt gorm.DeletedAt `gorm:"index"`
+ ExpiredTime int64 `json:"expired_time" gorm:"bigint"` // 过期时间,0 表示不过期
}
func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
@@ -131,6 +132,9 @@ func Redeem(key string, userId int) (quota int, err error) {
if redemption.Status != common.RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
+ if redemption.ExpiredTime != 0 && redemption.ExpiredTime < common.GetTimestamp() {
+ return errors.New("该兑换码已过期")
+ }
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
if err != nil {
return err
@@ -162,7 +166,7 @@ func (redemption *Redemption) SelectUpdate() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (redemption *Redemption) Update() error {
var err error
- err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
+ err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time", "expired_time").Updates(redemption).Error
return err
}
@@ -183,3 +187,9 @@ func DeleteRedemptionById(id int) (err error) {
}
return redemption.Delete()
}
+
+func DeleteInvalidRedemptions() (int64, error) {
+ now := common.GetTimestamp()
+ result := DB.Where("status IN ? OR (status = ? AND expired_time != 0 AND expired_time < ?)", []int{common.RedemptionCodeStatusUsed, common.RedemptionCodeStatusDisabled}, common.RedemptionCodeStatusEnabled, now).Delete(&Redemption{})
+ return result.RowsAffected, result.Error
+}
diff --git a/model/task.go b/model/task.go
index df221edf..9e4177ba 100644
--- a/model/task.go
+++ b/model/task.go
@@ -302,3 +302,64 @@ func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, e
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
return stat, err
}
+
+// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
+func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Task{})
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.UserID != "" {
+ query = query.Where("user_id = ?", queryParams.UserID)
+ }
+ if len(queryParams.UserIDs) != 0 {
+ query = query.Where("user_id in (?)", queryParams.UserIDs)
+ }
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
+
+// TaskCountAllUserTask returns total tasks for given user
+func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Task{}).Where("user_id = ?", userId)
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
diff --git a/model/token.go b/model/token.go
index 8587ea62..2ed2c09a 100644
--- a/model/token.go
+++ b/model/token.go
@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
if token != "" {
token = strings.Trim(token, "sk-")
}
- err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
+ err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
}
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
// Don't return error - fall through to DB
}
fromDB = true
- err = DB.Where(keyCol+" = ?", key).First(&token).Error
+ err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
return token, err
}
@@ -320,3 +320,10 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
+
+// CountUserTokens returns total number of tokens for the given user, used for pagination
+func CountUserTokens(userId int) (int64, error) {
+ var total int64
+ err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
+ return total, err
+}
diff --git a/model/token_cache.go b/model/token_cache.go
index 0fe02fea..a4b0beae 100644
--- a/model/token_cache.go
+++ b/model/token_cache.go
@@ -10,7 +10,7 @@ import (
func cacheSetToken(token Token) error {
key := common.GenerateHMAC(token.Key)
token.Clean()
- err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
+ err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
if err != nil {
return err
}
@@ -19,7 +19,7 @@ func cacheSetToken(token Token) error {
func cacheDeleteToken(key string) error {
key = common.GenerateHMAC(key)
- err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
+ err := common.RedisDelKey(fmt.Sprintf("token:%s", key))
if err != nil {
return err
}
diff --git a/model/user.go b/model/user.go
index 1a3372aa..6a695457 100644
--- a/model/user.go
+++ b/model/user.go
@@ -41,6 +41,7 @@ type User struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
+ Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
}
func (user *User) ToBaseUser() *UserBase {
@@ -175,7 +176,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
// 如果是数字,同时搜索ID和其他字段
likeCondition = "id = ? OR " + likeCondition
if group != "" {
- query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+ query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
@@ -184,7 +185,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
} else {
// 非数字关键字,只搜索字符串字段
if group != "" {
- query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+ query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
@@ -366,6 +367,7 @@ func (user *User) Edit(updatePassword bool) error {
"display_name": newUser.DisplayName,
"group": newUser.Group,
"quota": newUser.Quota,
+ "remark": newUser.Remark,
}
if updatePassword {
updates["password"] = newUser.Password
@@ -615,7 +617,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
// Don't return error - fall through to DB
}
fromDB = true
- err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
+ err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
if err != nil {
return "", err
}
diff --git a/model/user_cache.go b/model/user_cache.go
index bc412e77..e673defc 100644
--- a/model/user_cache.go
+++ b/model/user_cache.go
@@ -3,11 +3,12 @@ package model
import (
"encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
"time"
+ "github.com/gin-gonic/gin"
+
"github.com/bytedance/gopkg/util/gopool"
)
@@ -57,7 +58,7 @@ func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
- return common.RedisHDelObj(getUserCacheKey(userId))
+ return common.RedisDelKey(getUserCacheKey(userId))
}
// updateUserCache updates all user cache fields using hash
@@ -69,7 +70,7 @@ func updateUserCache(user User) error {
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
+ time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
)
}
diff --git a/model/utils.go b/model/utils.go
index e6b09aa5..1f8a0963 100644
--- a/model/utils.go
+++ b/model/utils.go
@@ -2,11 +2,12 @@ package model
import (
"errors"
- "github.com/bytedance/gopkg/util/gopool"
- "gorm.io/gorm"
"one-api/common"
"sync"
"time"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "gorm.io/gorm"
)
const (
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
+ // check if there's any data to update
+ hasData := false
+ for i := 0; i < BatchUpdateTypeCount; i++ {
+ batchUpdateLocks[i].Lock()
+ if len(batchUpdateStores[i]) > 0 {
+ hasData = true
+ batchUpdateLocks[i].Unlock()
+ break
+ }
+ batchUpdateLocks[i].Unlock()
+ }
+
+ if !hasData {
+ return
+ }
+
common.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
diff --git a/relay/relay-audio.go b/relay/audio_handler.go
similarity index 91%
rename from relay/relay-audio.go
rename to relay/audio_handler.go
index deb45c58..c1ce1a02 100644
--- a/relay/relay-audio.go
+++ b/relay/audio_handler.go
@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
}
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
- relayInfo := relaycommon.GenRelayInfo(c)
+ relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil {
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
- promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
- }
+ promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
preConsumedTokens = promptTokens
relayInfo.PromptTokens = promptTokens
}
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}()
- err = helper.ModelMappedHelper(c, relayInfo)
+ err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- audioRequest.Model = relayInfo.UpstreamModelName
-
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
index 50255d0a..873997f6 100644
--- a/relay/channel/adapter.go
+++ b/relay/channel/adapter.go
@@ -44,4 +44,6 @@ type TaskAdaptor interface {
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
+
+ ParseResultUrl(resp map[string]any) (string, error)
}
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index ab632d22..f30d4dc4 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -31,6 +31,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
+ case constant.RelayModeRerank:
+ fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
case constant.RelayModeCompletions:
@@ -57,6 +59,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
+
+ // fix: ali parameter.enable_thinking must be set to false for non-streaming calls
+ if !info.IsStream {
+ request.EnableThinking = false
+ }
+
switch info.RelayMode {
default:
aliReq := requestOpenAI2Ali(*request)
@@ -70,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, errors.New("not implemented")
+ return ConvertRerankRequest(request), nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
@@ -97,6 +105,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = aliImageHandler(c, resp, info)
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
+ case constant.RelayModeRerank:
+ err, usage = RerankHandler(c, resp, info)
default:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
diff --git a/relay/channel/ali/constants.go b/relay/channel/ali/constants.go
index 46de5e40..df64439b 100644
--- a/relay/channel/ali/constants.go
+++ b/relay/channel/ali/constants.go
@@ -8,6 +8,7 @@ var ModelList = []string{
"qwq-32b",
"qwen3-235b-a22b",
"text-embedding-v1",
+ "gte-rerank-v2",
}
var ChannelName = "ali"
diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go
index f51286ad..dbd18968 100644
--- a/relay/channel/ali/dto.go
+++ b/relay/channel/ali/dto.go
@@ -1,5 +1,7 @@
package ali
+import "one-api/dto"
+
type AliMessage struct {
Content string `json:"content"`
Role string `json:"role"`
@@ -97,3 +99,28 @@ type AliImageRequest struct {
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
+
+type AliRerankParameters struct {
+ TopN *int `json:"top_n,omitempty"`
+ ReturnDocuments *bool `json:"return_documents,omitempty"`
+}
+
+type AliRerankInput struct {
+ Query string `json:"query"`
+ Documents []any `json:"documents"`
+}
+
+type AliRerankRequest struct {
+ Model string `json:"model"`
+ Input AliRerankInput `json:"input"`
+ Parameters AliRerankParameters `json:"parameters,omitempty"`
+}
+
+type AliRerankResponse struct {
+ Output struct {
+ Results []dto.RerankResponseResult `json:"results"`
+ } `json:"output"`
+ Usage AliUsage `json:"usage"`
+ RequestId string `json:"request_id"`
+ AliError
+}
diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go
new file mode 100644
index 00000000..c9ae066a
--- /dev/null
+++ b/relay/channel/ali/rerank.go
@@ -0,0 +1,83 @@
+package ali
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
+ returnDocuments := request.ReturnDocuments
+ if returnDocuments == nil {
+ t := true
+ returnDocuments = &t
+ }
+ return &AliRerankRequest{
+ Model: request.Model,
+ Input: AliRerankInput{
+ Query: request.Query,
+ Documents: request.Documents,
+ },
+ Parameters: AliRerankParameters{
+ TopN: &request.TopN,
+ ReturnDocuments: returnDocuments,
+ },
+ }
+}
+
+func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+
+ var aliResponse AliRerankResponse
+ err = json.Unmarshal(responseBody, &aliResponse)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+
+ if aliResponse.Code != "" {
+ return &dto.OpenAIErrorWithStatusCode{
+ Error: dto.OpenAIError{
+ Message: aliResponse.Message,
+ Type: aliResponse.Code,
+ Param: aliResponse.RequestId,
+ Code: aliResponse.Code,
+ },
+ StatusCode: resp.StatusCode,
+ }, nil
+ }
+
+ usage := dto.Usage{
+ PromptTokens: aliResponse.Usage.TotalTokens,
+ CompletionTokens: 0,
+ TotalTokens: aliResponse.Usage.TotalTokens,
+ }
+ rerankResponse := dto.RerankResponse{
+ Results: aliResponse.Output.Results,
+ Usage: usage,
+ }
+
+ jsonResponse, err := json.Marshal(rerankResponse)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
+ }
+
+ return nil, &usage
+}
diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go
index 3fe893b3..2f1387c5 100644
--- a/relay/channel/ali/text.go
+++ b/relay/channel/ali/text.go
@@ -3,7 +3,6 @@ package ali
import (
"bufio"
"encoding/json"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -11,6 +10,8 @@ import (
"one-api/relay/helper"
"one-api/service"
"strings"
+
+ "github.com/gin-gonic/gin"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@@ -27,9 +28,6 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
}
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
- if request.Model == "" {
- request.Model = "text-embedding-v1"
- }
return &AliEmbeddingRequest{
Model: request.Model,
Input: struct {
@@ -64,7 +62,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
}, nil
}
- fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
+ model := c.GetString("model")
+ if model == "" {
+ model = "text-embedding-v4"
+ }
+ fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
@@ -75,11 +77,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
return nil, &fullTextResponse.Usage
}
-func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
- Model: "text-embedding-v1",
+ Model: model,
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
}
@@ -94,12 +96,11 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
}
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
- content, _ := json.Marshal(response.Output.Text)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index 03eff9cf..c3da5134 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -104,6 +104,105 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
return targetConn, nil
}
+func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
+ pingerCtx, stopPinger := context.WithCancel(context.Background())
+
+ gopool.Go(func() {
+ defer func() {
+ // 增加panic恢复处理
+ if r := recover(); r != nil {
+ if common2.DebugEnabled {
+ println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
+ }
+ }
+ if common2.DebugEnabled {
+ println("SSE ping goroutine stopped.")
+ }
+ }()
+
+ if pingInterval <= 0 {
+ pingInterval = helper.DefaultPingInterval
+ }
+
+ ticker := time.NewTicker(pingInterval)
+ // 确保在任何情况下都清理ticker
+ defer func() {
+ ticker.Stop()
+ if common2.DebugEnabled {
+ println("SSE ping ticker stopped")
+ }
+ }()
+
+ var pingMutex sync.Mutex
+ if common2.DebugEnabled {
+ println("SSE ping goroutine started")
+ }
+
+ // 增加超时控制,防止goroutine长时间运行
+ maxPingDuration := 120 * time.Minute // 最大ping持续时间
+ pingTimeout := time.NewTimer(maxPingDuration)
+ defer pingTimeout.Stop()
+
+ for {
+ select {
+ // 发送 ping 数据
+ case <-ticker.C:
+ if err := sendPingData(c, &pingMutex); err != nil {
+ if common2.DebugEnabled {
+ println("SSE ping error, stopping goroutine:", err.Error())
+ }
+ return
+ }
+ // 收到退出信号
+ case <-pingerCtx.Done():
+ return
+ // request 结束
+ case <-c.Request.Context().Done():
+ return
+ // 超时保护,防止goroutine无限运行
+ case <-pingTimeout.C:
+ if common2.DebugEnabled {
+ println("SSE ping goroutine timeout, stopping")
+ }
+ return
+ }
+ }
+ })
+
+ return stopPinger
+}
+
+func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
+ // 增加超时控制,防止锁死等待
+ done := make(chan error, 1)
+ go func() {
+ mutex.Lock()
+ defer mutex.Unlock()
+
+ err := helper.PingData(c)
+ if err != nil {
+ common2.LogError(c, "SSE ping error: "+err.Error())
+ done <- err
+ return
+ }
+
+ if common2.DebugEnabled {
+ println("SSE ping data sent.")
+ }
+ done <- nil
+ }()
+
+ // 设置发送ping数据的超时时间
+ select {
+ case err := <-done:
+ return err
+ case <-time.After(10 * time.Second):
+ return errors.New("SSE ping data send timeout")
+ case <-c.Request.Context().Done():
+ return errors.New("request context cancelled during ping")
+ }
+}
+
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error
@@ -115,68 +214,36 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
} else {
client = service.GetHttpClient()
}
- // 流式请求 ping 保活
- var stopPinger func()
- generalSettings := operation_setting.GetGeneralSetting()
- pingEnabled := generalSettings.PingIntervalEnabled
- var pingerWg sync.WaitGroup
+
+ var stopPinger context.CancelFunc
if info.IsStream {
helper.SetEventStreamHeaders(c)
- pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
- var pingerCtx context.Context
- pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
-
- if pingEnabled {
- pingerWg.Add(1)
- gopool.Go(func() {
- defer pingerWg.Done()
- if pingInterval <= 0 {
- pingInterval = helper.DefaultPingInterval
- }
-
- ticker := time.NewTicker(pingInterval)
- defer ticker.Stop()
- var pingMutex sync.Mutex
- if common2.DebugEnabled {
- println("SSE ping goroutine started")
- }
-
- for {
- select {
- case <-ticker.C:
- pingMutex.Lock()
- err2 := helper.PingData(c)
- pingMutex.Unlock()
- if err2 != nil {
- common2.LogError(c, "SSE ping error: "+err.Error())
- return
- }
- if common2.DebugEnabled {
- println("SSE ping data sent.")
- }
- case <-pingerCtx.Done():
- if common2.DebugEnabled {
- println("SSE ping goroutine stopped.")
- }
- return
+ // 处理流式请求的 ping 保活
+ generalSettings := operation_setting.GetGeneralSetting()
+ if generalSettings.PingIntervalEnabled {
+ pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+ stopPinger = startPingKeepAlive(c, pingInterval)
+ // 使用defer确保在任何情况下都能停止ping goroutine
+ defer func() {
+ if stopPinger != nil {
+ stopPinger()
+ if common2.DebugEnabled {
+ println("SSE ping goroutine stopped by defer")
}
}
- })
+ }()
}
}
resp, err := client.Do(req)
- // request结束后停止ping
- if info.IsStream && pingEnabled {
- stopPinger()
- pingerWg.Wait()
- }
+
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("resp is nil")
}
+
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil
diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go
index 37196fd8..64c7b747 100644
--- a/relay/channel/aws/constants.go
+++ b/relay/channel/aws/constants.go
@@ -11,6 +11,8 @@ var awsModelIDMap = map[string]string{
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
+ "claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
+ "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
}
var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -41,6 +43,16 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"us": true,
+ "ap": true,
+ "eu": true,
+ },
+ "anthropic.claude-sonnet-4-20250514-v1:0": {
+ "us": true,
+ "ap": true,
+ "eu": true,
+ },
+ "anthropic.claude-opus-4-20250514-v1:0": {
+ "us": true,
},
}
diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go
index 62b06413..55b6c137 100644
--- a/relay/channel/baidu/relay-baidu.go
+++ b/relay/channel/baidu/relay-baidu.go
@@ -53,12 +53,11 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
- content, _ := json.Marshal(response.Result)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: response.Result,
},
FinishReason: "stop",
}
diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go
index 77afe2dd..2b8a52a2 100644
--- a/relay/channel/baidu_v2/adaptor.go
+++ b/relay/channel/baidu_v2/adaptor.go
@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
+ "strings"
"github.com/gin-gonic/gin"
)
@@ -49,6 +50,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
+ if strings.HasSuffix(info.UpstreamModelName, "-search") {
+ info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
+ request.Model = info.UpstreamModelName
+ toMap := request.ToMap()
+ toMap["web_search"] = map[string]any{
+ "enable": true,
+ "enable_citation": true,
+ "enable_trace": true,
+ "enable_status": false,
+ }
+ return toMap, nil
+ }
return request, nil
}
diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go
index 4b071712..8389b9f1 100644
--- a/relay/channel/claude/adaptor.go
+++ b/relay/channel/claude/adaptor.go
@@ -38,10 +38,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
- if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
- a.RequestMode = RequestModeMessage
- } else {
+ if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
a.RequestMode = RequestModeCompletion
+ } else {
+ a.RequestMode = RequestModeMessage
}
}
diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go
index d7e0c8e3..e0e3c421 100644
--- a/relay/channel/claude/constants.go
+++ b/relay/channel/claude/constants.go
@@ -13,6 +13,10 @@ var ModelList = []string{
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking",
+ "claude-sonnet-4-20250514",
+ "claude-sonnet-4-20250514-thinking",
+ "claude-opus-4-20250514",
+ "claude-opus-4-20250514-thinking",
}
var ChannelName = "claude"
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 95e7c4be..406ebc8a 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -48,9 +48,9 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
- prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
+ prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
} else if message.Role == "assistant" {
- prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
+ prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
@@ -113,7 +113,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
- BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
+ BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
@@ -155,15 +155,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
}
if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
if lastMessage.IsStringContent() && message.IsStringContent() {
- content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
- fmtMessage.Content = content
+ fmtMessage.SetStringContent(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
// delete last message
formatMessages = formatMessages[:len(formatMessages)-1]
}
}
if fmtMessage.Content == nil {
- content, _ := json.Marshal("...")
- fmtMessage.Content = content
+ fmtMessage.SetStringContent("...")
}
formatMessages = append(formatMessages, fmtMessage)
lastMessage = fmtMessage
@@ -397,12 +395,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
thinkingContent := ""
if reqMode == RequestModeCompletion {
- content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
@@ -457,6 +454,7 @@ type ClaudeResponseInfo struct {
Model string
ResponseText strings.Builder
Usage *dto.Usage
+ Done bool
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
@@ -464,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
- // message_start, 获取usage
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
+
+ // message_start, 获取usage
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
+ claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
+ if claudeResponse.Delta.Thinking != "" {
+ claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
+ }
} else if claudeResponse.Type == "message_delta" {
- claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+ // 最终的usage获取
if claudeResponse.Usage.InputTokens > 0 {
+ // 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
- claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
+ claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+ claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
+
+ // 判断是否完整
+ claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
@@ -509,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
+ FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
+
if requestMode == RequestModeCompletion {
- claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
info.UpstreamModelName = claudeResponse.Message.Model
- claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
- claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
- claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
} else if claudeResponse.Type == "message_delta" {
- if claudeResponse.Usage.InputTokens > 0 {
- // 不叠加,只取最新的
- claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
- }
- claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
- claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
@@ -547,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
+
+ if requestMode == RequestModeCompletion {
+ claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+ } else {
+ if claudeInfo.Usage.PromptTokens == 0 {
+ //上游出错
+ }
+ if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
+ if common.DebugEnabled {
+ common.SysError("claude response usage is not complete, maybe upstream error")
+ }
+ claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+ }
+ }
+
if info.RelayFormat == relaycommon.RelayFormatClaude {
- if requestMode == RequestModeCompletion {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
- } else {
- // 说明流模式建立失败,可能为官方出错
- if claudeInfo.Usage.PromptTokens == 0 {
- //usage.PromptTokens = info.PromptTokens
- }
- if claudeInfo.Usage.CompletionTokens == 0 {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
- }
- }
+ //
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
- if requestMode == RequestModeCompletion {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
- } else {
- if claudeInfo.Usage.PromptTokens == 0 {
- //上游出错
- }
- if claudeInfo.Usage.CompletionTokens == 0 {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
- }
- }
+
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
@@ -622,10 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
}
if requestMode == RequestModeCompletion {
- completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
- }
+ completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
claudeInfo.Usage.PromptTokens = info.PromptTokens
claudeInfo.Usage.CompletionTokens = completionTokens
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go
index a487429c..50d4928a 100644
--- a/relay/channel/cloudflare/relay_cloudflare.go
+++ b/relay/channel/cloudflare/relay_cloudflare.go
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
}
- usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
- usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
+ usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go
index 17b58dbc..29064242 100644
--- a/relay/channel/cohere/relay-cohere.go
+++ b/relay/channel/cohere/relay-cohere.go
@@ -3,7 +3,6 @@ package cohere
import (
"bufio"
"encoding/json"
- "fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -78,7 +77,7 @@ func stopReasonCohere2OpenAI(reason string) string {
}
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+ responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
responseText := ""
@@ -163,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
})
if usage.PromptTokens == 0 {
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return nil, usage
}
@@ -195,11 +194,10 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
openaiResp.Model = modelName
openaiResp.Usage = usage
- content, _ := json.Marshal(cohereResp.Text)
openaiResp.Choices = []dto.OpenAITextResponseChoice{
{
Index: 0,
- Message: dto.Message{Content: content, Role: "assistant"},
+ Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
},
}
diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go
index 4e9afa23..d5dc9a81 100644
--- a/relay/channel/coze/dto.go
+++ b/relay/channel/coze/dto.go
@@ -10,7 +10,7 @@ type CozeError struct {
type CozeEnterMessage struct {
Role string `json:"role"`
Type string `json:"type,omitempty"`
- Content json.RawMessage `json:"content,omitempty"`
+ Content any `json:"content,omitempty"`
MetaData json.RawMessage `json:"meta_data,omitempty"`
ContentType string `json:"content_type,omitempty"`
}
diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
index 6db40213..ac76476f 100644
--- a/relay/channel/coze/relay-coze.go
+++ b/relay/channel/coze/relay-coze.go
@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
var currentEvent string
var currentData string
- var usage dto.Usage
+ var usage = &dto.Usage{}
for scanner.Scan() {
line := scanner.Text()
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
if line == "" {
if currentEvent != "" && currentData != "" {
// handle last event
- handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+ handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
currentEvent = ""
currentData = ""
}
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
// Last event
if currentEvent != "" && currentData != "" {
- handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+ handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
}
if err := scanner.Err(); err != nil {
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
helper.Done(c)
if usage.TotalTokens == 0 {
- usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
}
- return nil, &usage
+ return nil, usage
}
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go
index b58fbe53..115aed1b 100644
--- a/relay/channel/dify/relay-dify.go
+++ b/relay/channel/dify/relay-dify.go
@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
return true
})
helper.Done(c)
- err := resp.Body.Close()
- if err != nil {
- // return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- common.SysError("close_response_body_failed: " + err.Error())
- }
if usage.TotalTokens == 0 {
- usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
usage.CompletionTokens += nodeToken
return nil, usage
@@ -278,12 +271,11 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
Created: common.GetTimestamp(),
Usage: difyResponse.MetaData.Usage,
}
- content, _ := json.Marshal(difyResponse.Answer)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: difyResponse.Answer,
},
FinishReason: "stop",
}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index c3c7b49d..968d9c9b 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -10,6 +10,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
"one-api/service"
"one-api/setting/model_setting"
"strings"
@@ -71,10 +72,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- // suffix -thinking and -nothinking
- if strings.HasSuffix(info.OriginModelName, "-thinking") {
+ // 新增逻辑:处理 -thinking- 格式
+ if strings.Contains(info.UpstreamModelName, "-thinking-") {
+ parts := strings.Split(info.UpstreamModelName, "-thinking-")
+ info.UpstreamModelName = parts[0]
+ } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
- } else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
+ } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
@@ -165,6 +169,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+ if info.RelayMode == constant.RelayModeGemini {
+ if info.IsStream {
+ return GeminiTextGenerationStreamHandler(c, resp, info)
+ } else {
+ return GeminiTextGenerationHandler(c, resp, info)
+ }
+ }
+
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return GeminiImageHandler(c, resp, info)
}
diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go
index 5d5c1287..b22e092a 100644
--- a/relay/channel/gemini/dto.go
+++ b/relay/channel/gemini/dto.go
@@ -1,11 +1,13 @@
package gemini
+import "encoding/json"
+
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
- SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
- GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
+ SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
+ GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
Tools []GeminiChatTool `json:"tools,omitempty"`
- SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
+ SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
}
type GeminiThinkingConfig struct {
@@ -22,19 +24,38 @@ type GeminiInlineData struct {
Data string `json:"data"`
}
+// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
+func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
+ type Alias GeminiInlineData // Use type alias to avoid recursion
+ var aux struct {
+ Alias
+ MimeTypeSnake string `json:"mime_type"`
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ *g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
+
+ // Prioritize snake_case if present
+ if aux.MimeTypeSnake != "" {
+ g.MimeType = aux.MimeTypeSnake
+ } else if aux.MimeType != "" { // Fallback to camelCase from Alias
+ g.MimeType = aux.MimeType
+ }
+ // g.Data would be populated by aux.Alias.Data
+ return nil
+}
+
type FunctionCall struct {
FunctionName string `json:"name"`
Arguments any `json:"args"`
}
-type GeminiFunctionResponseContent struct {
- Name string `json:"name"`
- Content any `json:"content"`
-}
-
type FunctionResponse struct {
- Name string `json:"name"`
- Response GeminiFunctionResponseContent `json:"response"`
+ Name string `json:"name"`
+ Response map[string]interface{} `json:"response"`
}
type GeminiPartExecutableCode struct {
@@ -54,6 +75,7 @@ type GeminiFileData struct {
type GeminiPart struct {
Text string `json:"text,omitempty"`
+ Thought bool `json:"thought,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
@@ -62,6 +84,33 @@ type GeminiPart struct {
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
}
+// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
+func (p *GeminiPart) UnmarshalJSON(data []byte) error {
+ // Alias to avoid recursion during unmarshalling
+ type Alias GeminiPart
+ var aux struct {
+ Alias
+ InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ // Assign fields from alias
+ *p = GeminiPart(aux.Alias)
+
+ // Prioritize snake_case for InlineData if present
+ if aux.InlineDataSnake != nil {
+ p.InlineData = aux.InlineDataSnake
+ } else if aux.InlineData != nil { // Fallback to camelCase from Alias
+ p.InlineData = aux.InlineData
+ }
+ // Other fields like Text, FunctionCall etc. are already populated via aux.Alias
+
+ return nil
+}
+
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
@@ -91,6 +140,7 @@ type GeminiChatGenerationConfig struct {
Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
+ SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
}
type GeminiChatCandidate struct {
@@ -116,10 +166,16 @@ type GeminiChatResponse struct {
}
type GeminiUsageMetadata struct {
- PromptTokenCount int `json:"promptTokenCount"`
- CandidatesTokenCount int `json:"candidatesTokenCount"`
- TotalTokenCount int `json:"totalTokenCount"`
- ThoughtsTokenCount int `json:"thoughtsTokenCount"`
+ PromptTokenCount int `json:"promptTokenCount"`
+ CandidatesTokenCount int `json:"candidatesTokenCount"`
+ TotalTokenCount int `json:"totalTokenCount"`
+ ThoughtsTokenCount int `json:"thoughtsTokenCount"`
+ PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+}
+
+type GeminiPromptTokensDetails struct {
+ Modality string `json:"modality"`
+ TokenCount int `json:"tokenCount"`
}
// Imagen related structs
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
new file mode 100644
index 00000000..39757cef
--- /dev/null
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -0,0 +1,146 @@
+package gemini
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+ // 读取响应体
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ }
+
+ if common.DebugEnabled {
+ println(string(responseBody))
+ }
+
+ // 解析为 Gemini 原生响应格式
+ var geminiResponse GeminiChatResponse
+ err = common.DecodeJson(responseBody, &geminiResponse)
+ if err != nil {
+ return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ }
+
+ // 计算使用量(基于 UsageMetadata)
+ usage := dto.Usage{
+ PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
+ CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
+ TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
+ }
+
+ usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
+
+ // 直接返回 Gemini 原生格式的 JSON 响应
+ jsonResponse, err := json.Marshal(geminiResponse)
+ if err != nil {
+ return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
+ }
+
+ // 设置响应头并写入响应
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ if err != nil {
+ return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
+ }
+
+ return &usage, nil
+}
+
+func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+ var usage = &dto.Usage{}
+ var imageCount int
+
+ helper.SetEventStreamHeaders(c)
+
+ responseText := strings.Builder{}
+
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ var geminiResponse GeminiChatResponse
+ err := common.DecodeJsonStr(data, &geminiResponse)
+ if err != nil {
+ common.LogError(c, "error unmarshalling stream response: "+err.Error())
+ return false
+ }
+
+ // 统计图片数量
+ for _, candidate := range geminiResponse.Candidates {
+ for _, part := range candidate.Content.Parts {
+ if part.InlineData != nil && part.InlineData.MimeType != "" {
+ imageCount++
+ }
+ if part.Text != "" {
+ responseText.WriteString(part.Text)
+ }
+ }
+ }
+
+ // 更新使用量统计
+ if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+ usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+ usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
+ usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+ usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
+ }
+
+ // 直接发送 GeminiChatResponse 响应
+ err = helper.StringData(c, data)
+ if err != nil {
+ common.LogError(c, err.Error())
+ }
+
+ return true
+ })
+
+ if imageCount != 0 {
+ if usage.CompletionTokens == 0 {
+ usage.CompletionTokens = imageCount * 258
+ }
+ }
+
+ // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
+ if usage.CompletionTokens == 0 {
+ str := responseText.String()
+ if len(str) > 0 {
+ usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
+ } else {
+ // 空补全,不需要使用量
+ usage = &dto.Usage{}
+ }
+ }
+
+ // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
+ //helper.Done(c)
+
+ return usage, nil
+}
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index ae9a3b7b..18edfd04 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -12,12 +12,72 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
+ "strconv"
"strings"
"unicode/utf8"
"github.com/gin-gonic/gin"
)
+var geminiSupportedMimeTypes = map[string]bool{
+ "application/pdf": true,
+ "audio/mpeg": true,
+ "audio/mp3": true,
+ "audio/wav": true,
+ "image/png": true,
+ "image/jpeg": true,
+ "text/plain": true,
+ "video/mov": true,
+ "video/mpeg": true,
+ "video/mp4": true,
+ "video/mpg": true,
+ "video/avi": true,
+ "video/wmv": true,
+ "video/mpegps": true,
+ "video/flv": true,
+}
+
+// Gemini 允许的思考预算范围
+const (
+ pro25MinBudget = 128
+ pro25MaxBudget = 32768
+ flash25MaxBudget = 24576
+ flash25LiteMinBudget = 512
+ flash25LiteMaxBudget = 24576
+)
+
+// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
+func clampThinkingBudget(modelName string, budget int) int {
+ isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
+ is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
+
+ if is25FlashLite {
+ if budget < flash25LiteMinBudget {
+ return flash25LiteMinBudget
+ }
+ if budget > flash25LiteMaxBudget {
+ return flash25LiteMaxBudget
+ }
+ } else if isNew25Pro {
+ if budget < pro25MinBudget {
+ return pro25MinBudget
+ }
+ if budget > pro25MaxBudget {
+ return pro25MaxBudget
+ }
+ } else { // 其他模型
+ if budget < 0 {
+ return 0
+ }
+ if budget > flash25MaxBudget {
+ return flash25MaxBudget
+ }
+ }
+ return budget
+}
+
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
@@ -39,18 +99,54 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- if strings.HasSuffix(info.OriginModelName, "-thinking") {
- budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
- if budgetTokens == 0 || budgetTokens > 24576 {
- budgetTokens = 24576
+ modelName := info.UpstreamModelName
+ isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
+
+ if strings.Contains(modelName, "-thinking-") {
+ parts := strings.SplitN(modelName, "-thinking-", 2)
+ if len(parts) == 2 && parts[1] != "" {
+ if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
+ clampedBudget := clampThinkingBudget(modelName, budgetTokens)
+ geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+ ThinkingBudget: common.GetPointer(clampedBudget),
+ IncludeThoughts: true,
+ }
+ }
}
- geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
- ThinkingBudget: common.GetPointer(int(budgetTokens)),
- IncludeThoughts: true,
+ } else if strings.HasSuffix(modelName, "-thinking") {
+ unsupportedModels := []string{
+ "gemini-2.5-pro-preview-05-06",
+ "gemini-2.5-pro-preview-03-25",
}
- } else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
- geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
- ThinkingBudget: common.GetPointer(0),
+ isUnsupported := false
+ for _, unsupportedModel := range unsupportedModels {
+ if strings.HasPrefix(modelName, unsupportedModel) {
+ isUnsupported = true
+ break
+ }
+ }
+
+ if isUnsupported {
+ geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+ IncludeThoughts: true,
+ }
+ } else {
+ geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+ IncludeThoughts: true,
+ }
+ if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
+ budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
+ clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
+ geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
+ }
+ }
+ } else if strings.HasSuffix(modelName, "-nothinking") {
+ if !isNew25Pro {
+ geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+ ThinkingBudget: common.GetPointer(0),
+ }
}
}
}
@@ -112,12 +208,6 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
// json_data, _ := json.Marshal(geminiRequest.Tools)
// common.SysLog("tools_json: " + string(json_data))
- } else if textRequest.Functions != nil {
- //geminiRequest.Tools = []GeminiChatTool{
- // {
- // FunctionDeclarations: textRequest.Functions,
- // },
- //}
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -148,17 +238,27 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
name = val
}
- content := common.StrToMap(message.StringContent())
+ var contentMap map[string]interface{}
+ contentStr := message.StringContent()
+
+ // 1. 尝试解析为 JSON 对象
+ if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
+ // 2. 如果失败,尝试解析为 JSON 数组
+ var contentSlice []interface{}
+ if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
+ // 如果是数组,包装成对象
+ contentMap = map[string]interface{}{"result": contentSlice}
+ } else {
+ // 3. 如果再次失败,作为纯文本处理
+ contentMap = map[string]interface{}{"content": contentStr}
+ }
+ }
+
functionResp := &FunctionResponse{
- Name: name,
- Response: GeminiFunctionResponseContent{
- Name: name,
- Content: content,
- },
- }
- if content == nil {
- functionResp.Response.Content = message.StringContent()
+ Name: name,
+ Response: contentMap,
}
+
*parts = append(*parts, GeminiPart{
FunctionResponse: functionResp,
})
@@ -208,14 +308,21 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
// 判断是否是url
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
- // 是url,获取图片的类型和base64编码的数据
+ // 是url,获取文件的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
if err != nil {
- return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
+ return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
}
+
+ // 校验 MimeType 是否在 Gemini 支持的白名单中
+ if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
+ url := part.GetImageMedia().Url
+ return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
+ }
+
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
- MimeType: fileData.MimeType,
+ MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
Data: fileData.Base64Data,
},
})
@@ -249,13 +356,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if part.GetInputAudio().Data == "" {
return nil, fmt.Errorf("only base64 audio is supported in gemini")
}
- format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
+ base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
- MimeType: format,
+ MimeType: "audio/" + part.GetInputAudio().Format,
Data: base64String,
},
})
@@ -268,7 +375,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if content.Role == "assistant" {
content.Role = "model"
}
- geminiRequest.Contents = append(geminiRequest.Contents, content)
+ if len(content.Parts) > 0 {
+ geminiRequest.Contents = append(geminiRequest.Contents, content)
+ }
}
if len(system_content) > 0 {
@@ -284,100 +393,126 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
return &geminiRequest, nil
}
+// Helper function to get a list of supported MIME types for error messages
+func getSupportedMimeTypesList() []string {
+ keys := make([]string, 0, len(geminiSupportedMimeTypes))
+ for k := range geminiSupportedMimeTypes {
+ keys = append(keys, k)
+ }
+ return keys
+}
+
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
func cleanFunctionParameters(params interface{}) interface{} {
if params == nil {
return nil
}
- paramMap, ok := params.(map[string]interface{})
- if !ok {
- // Not a map, return as is (e.g., could be an array or primitive)
- return params
- }
+ switch v := params.(type) {
+ case map[string]interface{}:
+ // Create a copy to avoid modifying the original
+ cleanedMap := make(map[string]interface{})
+ for k, val := range v {
+ cleanedMap[k] = val
+ }
- // Create a copy to avoid modifying the original
- cleanedMap := make(map[string]interface{})
- for k, v := range paramMap {
- cleanedMap[k] = v
- }
+ // Remove unsupported root-level fields
+ delete(cleanedMap, "default")
+ delete(cleanedMap, "exclusiveMaximum")
+ delete(cleanedMap, "exclusiveMinimum")
+ delete(cleanedMap, "$schema")
+ delete(cleanedMap, "additionalProperties")
- // Remove unsupported root-level fields
- delete(cleanedMap, "default")
- delete(cleanedMap, "exclusiveMaximum")
- delete(cleanedMap, "exclusiveMinimum")
- delete(cleanedMap, "$schema")
- delete(cleanedMap, "additionalProperties")
-
- // Clean properties
- if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
- cleanedProps := make(map[string]interface{})
- for propName, propValue := range props {
- propMap, ok := propValue.(map[string]interface{})
- if !ok {
- cleanedProps[propName] = propValue // Keep non-map properties
- continue
- }
-
- // Create a copy of the property map
- cleanedPropMap := make(map[string]interface{})
- for k, v := range propMap {
- cleanedPropMap[k] = v
- }
-
- // Remove unsupported fields
- delete(cleanedPropMap, "default")
- delete(cleanedPropMap, "exclusiveMaximum")
- delete(cleanedPropMap, "exclusiveMinimum")
- delete(cleanedPropMap, "$schema")
- delete(cleanedPropMap, "additionalProperties")
-
- // Check and clean 'format' for string types
- if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" {
- if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists {
- if formatValue != "enum" && formatValue != "date-time" {
- delete(cleanedPropMap, "format")
- }
+ // Check and clean 'format' for string types
+ if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
+ if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
+ if formatValue != "enum" && formatValue != "date-time" {
+ delete(cleanedMap, "format")
}
}
+ }
- // Recursively clean nested properties within this property if it's an object/array
- // Check the type before recursing
- if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") {
- cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap)
- } else {
- cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing
+ // Clean properties
+ if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
+ cleanedProps := make(map[string]interface{})
+ for propName, propValue := range props {
+ cleanedProps[propName] = cleanFunctionParameters(propValue)
}
-
+ cleanedMap["properties"] = cleanedProps
}
- cleanedMap["properties"] = cleanedProps
- }
- // Recursively clean items in arrays if needed (e.g., type: array, items: { ... })
- if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
- cleanedMap["items"] = cleanFunctionParameters(items)
- }
- // Also handle items if it's an array of schemas
- if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
- cleanedItemsArray := make([]interface{}, len(itemsArray))
- for i, item := range itemsArray {
- cleanedItemsArray[i] = cleanFunctionParameters(item)
+ // Recursively clean items in arrays
+ if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
+ cleanedMap["items"] = cleanFunctionParameters(items)
}
- cleanedMap["items"] = cleanedItemsArray
- }
-
- // Recursively clean other schema composition keywords if necessary
- for _, field := range []string{"allOf", "anyOf", "oneOf"} {
- if nested, ok := cleanedMap[field].([]interface{}); ok {
- cleanedNested := make([]interface{}, len(nested))
- for i, item := range nested {
- cleanedNested[i] = cleanFunctionParameters(item)
+ // Also handle items if it's an array of schemas
+ if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
+ cleanedItemsArray := make([]interface{}, len(itemsArray))
+ for i, item := range itemsArray {
+ cleanedItemsArray[i] = cleanFunctionParameters(item)
}
- cleanedMap[field] = cleanedNested
+ cleanedMap["items"] = cleanedItemsArray
}
- }
- return cleanedMap
+ // Recursively clean other schema composition keywords
+ for _, field := range []string{"allOf", "anyOf", "oneOf"} {
+ if nested, ok := cleanedMap[field].([]interface{}); ok {
+ cleanedNested := make([]interface{}, len(nested))
+ for i, item := range nested {
+ cleanedNested[i] = cleanFunctionParameters(item)
+ }
+ cleanedMap[field] = cleanedNested
+ }
+ }
+
+ // Recursively clean patternProperties
+ if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
+ cleanedPatternProps := make(map[string]interface{})
+ for pattern, schema := range patternProps {
+ cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
+ }
+ cleanedMap["patternProperties"] = cleanedPatternProps
+ }
+
+ // Recursively clean definitions
+ if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
+ cleanedDefinitions := make(map[string]interface{})
+ for defName, defSchema := range definitions {
+ cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
+ }
+ cleanedMap["definitions"] = cleanedDefinitions
+ }
+
+ // Recursively clean $defs (newer JSON Schema draft)
+ if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
+ cleanedDefs := make(map[string]interface{})
+ for defName, defSchema := range defs {
+ cleanedDefs[defName] = cleanFunctionParameters(defSchema)
+ }
+ cleanedMap["$defs"] = cleanedDefs
+ }
+
+ // Clean conditional keywords
+ for _, field := range []string{"if", "then", "else", "not"} {
+ if nested, ok := cleanedMap[field]; ok {
+ cleanedMap[field] = cleanFunctionParameters(nested)
+ }
+ }
+
+ return cleanedMap
+
+ case []interface{}:
+ // Handle arrays of schemas
+ cleanedArray := make([]interface{}, len(v))
+ for i, item := range v {
+ cleanedArray[i] = cleanFunctionParameters(item)
+ }
+ return cleanedArray
+
+ default:
+ // Not a map or array, return as is (e.g., could be a primitive)
+ return params
+ }
}
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
@@ -512,21 +647,20 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
}
}
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
+func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ Id: helper.GetResponseID(c),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
- content, _ := json.Marshal("")
isToolCall := false
for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: int(candidate.Index),
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: "",
},
FinishReason: constant.FinishReasonStop,
}
@@ -539,6 +673,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call)
}
+ } else if part.Thought {
+ choice.Message.ReasoningContent = part.Text
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
@@ -556,7 +692,6 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
-
choice.Message.SetStringContent(strings.Join(texts, "\n"))
}
@@ -596,6 +731,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
}
var texts []string
isTools := false
+ isThought := false
if candidate.FinishReason != nil {
// p := GeminiConvertFinishReason(*candidate.FinishReason)
switch *candidate.FinishReason {
@@ -620,6 +756,9 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
+ } else if part.Thought {
+ isThought = true
+ texts = append(texts, part.Text)
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
@@ -632,7 +771,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
}
}
}
- choice.Delta.SetContentString(strings.Join(texts, "\n"))
+ if isThought {
+ choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
+ } else {
+ choice.Delta.SetContentString(strings.Join(texts, "\n"))
+ }
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
}
@@ -647,7 +790,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// responseText := ""
- id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+ id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
var imageCount int
@@ -672,6 +815,13 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
}
err = helper.ObjectData(c, response)
if err != nil {
@@ -716,8 +866,11 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
+ if common.DebugEnabled {
+ println(string(responseBody))
+ }
var geminiResponse GeminiChatResponse
- err = json.Unmarshal(responseBody, &geminiResponse)
+ err = common.DecodeJson(responseBody, &geminiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
@@ -732,7 +885,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
StatusCode: resp.StatusCode,
}, nil
}
- fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
+ fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
@@ -743,6 +896,14 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
+
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go
index 75272e34..e26c6101 100644
--- a/relay/channel/mistral/text.go
+++ b/relay/channel/mistral/text.go
@@ -1,13 +1,55 @@
package mistral
import (
+ "one-api/common"
"one-api/dto"
+ "regexp"
)
+var mistralToolCallIdRegexp = regexp.MustCompile("^[a-zA-Z0-9]{9}$")
+
func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
+ idMap := make(map[string]string)
for _, message := range request.Messages {
+ // 1. tool_calls.id
+ toolCalls := message.ParseToolCalls()
+ if toolCalls != nil {
+ for i := range toolCalls {
+ if !mistralToolCallIdRegexp.MatchString(toolCalls[i].ID) {
+ if newId, ok := idMap[toolCalls[i].ID]; ok {
+ toolCalls[i].ID = newId
+ } else {
+ newId, err := common.GenerateRandomCharsKey(9)
+ if err == nil {
+ idMap[toolCalls[i].ID] = newId
+ toolCalls[i].ID = newId
+ }
+ }
+ }
+ }
+ message.SetToolCalls(toolCalls)
+ }
+
+ // 2. tool_call_id
+ if message.ToolCallId != "" {
+ if newId, ok := idMap[message.ToolCallId]; ok {
+ message.ToolCallId = newId
+ } else {
+ if !mistralToolCallIdRegexp.MatchString(message.ToolCallId) {
+ newId, err := common.GenerateRandomCharsKey(9)
+ if err == nil {
+ idMap[message.ToolCallId] = newId
+ message.ToolCallId = newId
+ }
+ }
+ }
+ }
+
mediaMessages := message.ParseContent()
+ if message.Role == "assistant" && message.ToolCalls != nil && message.Content == "" {
+ mediaMessages = []dto.MediaContent{}
+ }
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.GetImageMedia()
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index f0cf073f..8358f3e2 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -88,6 +88,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
+
+ // 特殊处理 responses API
+ if info.RelayMode == constant.RelayModeResponses {
+ requestURL = "/openai/v1/responses?api-version=preview"
+ return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
+ }
+
model_ := info.UpstreamModelName
// 2025年5月10日后创建的渠道不移除.
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 86c47a15..71590cd6 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -15,6 +15,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"os"
+ "path/filepath"
"strings"
"github.com/bytedance/gopkg/util/gopool"
@@ -180,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
if !containStreamUsage {
- usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+ usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
if info.ChannelType == common.ChannelTypeDeepSeek {
@@ -215,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
StatusCode: resp.StatusCode,
}, nil
}
-
+
forceFormat := false
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
@@ -224,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
- ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+ ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
@@ -273,36 +274,25 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- // Reset response body
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
+ // the status code has been judged before, if there is a body reading failure,
+ // it should be regarded as a non-recoverable error, so it should not return err for external retry.
+ // Analogous to nginx's load balancing, it will only retry if it can't be requested or
+ // if the upstream returns a specific status code, once the upstream has already written the header,
+ // the subsequent failure of the response body should be regarded as a non-recoverable error,
+ // and can be terminated directly.
+ defer resp.Body.Close()
+ usage := &dto.Usage{}
+ usage.PromptTokens = info.PromptTokens
+ usage.TotalTokens = info.PromptTokens
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
+ c.Writer.WriteHeaderNow()
+ _, err := io.Copy(c.Writer, resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+ common.LogError(c, err.Error())
}
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
-
- usage := &dto.Usage{}
- usage.PromptTokens = info.PromptTokens
- usage.TotalTokens = info.PromptTokens
return nil, usage
}
@@ -356,13 +346,14 @@ func countAudioTokens(c *gin.Context) (int, error) {
if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err)
}
-
+ ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
reqFp, err := reqBody.File.Open()
if err != nil {
return 0, errors.WithStack(err)
}
+ defer reqFp.Close()
- tmpFp, err := os.CreateTemp("", "audio-*")
+ tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.WithStack(err)
}
@@ -376,7 +367,7 @@ func countAudioTokens(c *gin.Context) (int, error) {
return 0, errors.WithStack(err)
}
- duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
+ duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
if err != nil {
return 0, errors.WithStack(err)
}
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
index 1d1e060e..da9382c3 100644
--- a/relay/channel/openai/relay_responses.go
+++ b/relay/channel/openai/relay_responses.go
@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
tempStr := responseTextBuilder.String()
if len(tempStr) > 0 {
// 非正常结束,使用输出文本的 token 数量
- completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+ completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
usage.CompletionTokens = completionTokens
}
}
diff --git a/relay/channel/openrouter/dto.go b/relay/channel/openrouter/dto.go
new file mode 100644
index 00000000..607f495b
--- /dev/null
+++ b/relay/channel/openrouter/dto.go
@@ -0,0 +1,9 @@
+package openrouter
+
+type RequestReasoning struct {
+ // One of the following (not both):
+ Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
+ MaxTokens int `json:"max_tokens,omitempty"` // Specific token limit (Anthropic-style)
+ // Optional: Default is false. All models support this.
+ Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
+}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index 3a06e7ee..aee4a307 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index c8e337de..9d3dbd67 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -2,7 +2,6 @@ package palm
import (
"encoding/json"
- "fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -45,12 +44,11 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
- content, _ := json.Marshal(candidate.Content)
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: candidate.Content,
},
FinishReason: "stop",
}
@@ -74,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+ responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -157,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
- completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
+ completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
new file mode 100644
index 00000000..9ea58728
--- /dev/null
+++ b/relay/channel/task/kling/adaptor.go
@@ -0,0 +1,312 @@
+package kling
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt"
+ "github.com/pkg/errors"
+
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+type SubmitReq struct {
+ Prompt string `json:"prompt"`
+ Model string `json:"model,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ Image string `json:"image,omitempty"`
+ Size string `json:"size,omitempty"`
+ Duration int `json:"duration,omitempty"`
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+}
+
+type requestPayload struct {
+ Prompt string `json:"prompt,omitempty"`
+ Image string `json:"image,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ Duration string `json:"duration,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty"`
+ Model string `json:"model,omitempty"`
+ ModelName string `json:"model_name,omitempty"`
+ CfgScale float64 `json:"cfg_scale,omitempty"`
+}
+
+type responsePayload struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data struct {
+ TaskID string `json:"task_id"`
+ } `json:"data"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+type TaskAdaptor struct {
+ ChannelType int
+ accessKey string
+ secretKey string
+ baseURL string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+ a.ChannelType = info.ChannelType
+ a.baseURL = info.BaseUrl
+
+ // apiKey format: "access_key,secret_key"
+ keyParts := strings.Split(info.ApiKey, ",")
+ if len(keyParts) == 2 {
+ a.accessKey = strings.TrimSpace(keyParts[0])
+ a.secretKey = strings.TrimSpace(keyParts[1])
+ }
+}
+
+// ValidateRequestAndSetAction parses body, validates fields and sets default action.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+ // Accept only POST /v1/video/generations as "generate" action.
+ action := "generate"
+ info.Action = action
+
+ var req SubmitReq
+ if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+ taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+ return
+ }
+ if strings.TrimSpace(req.Prompt) == "" {
+ taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
+ return
+ }
+
+ // Store into context for later usage
+ c.Set("kling_request", req)
+ return nil
+}
+
+// BuildRequestURL constructs the upstream URL.
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+ return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
+}
+
+// BuildRequestHeader sets required headers.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+ token, err := a.createJWTToken()
+ if err != nil {
+ return fmt.Errorf("failed to create JWT token: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("User-Agent", "kling-sdk/1.0")
+ return nil
+}
+
+// BuildRequestBody converts request into Kling specific format.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+ v, exists := c.Get("kling_request")
+ if !exists {
+ return nil, fmt.Errorf("request not found in context")
+ }
+ req := v.(SubmitReq)
+
+ body := a.convertToRequestPayload(&req)
+ data, err := json.Marshal(body)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+// DoRequest delegates to common helper.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles upstream response, returns taskID etc.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Attempt Kling response parse first.
+ var kResp responsePayload
+ if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
+ c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
+ return kResp.Data.TaskID, responseBody, nil
+ }
+
+ // Fallback generic task response.
+ var generic dto.TaskResponse[string]
+ if err := json.Unmarshal(responseBody, &generic); err != nil {
+ taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+
+ if !generic.IsSuccess() {
+ taskErr = service.TaskErrorWrapper(errors.New(generic.Message), generic.Code, http.StatusInternalServerError)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
+ return generic.Data, responseBody, nil
+}
+
+// FetchTask fetch task status
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+ taskID, ok := body["task_id"].(string)
+ if !ok {
+ return nil, fmt.Errorf("invalid task_id")
+ }
+ url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
+
+ req, err := http.NewRequest(http.MethodGet, url, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ token, err := a.createJWTTokenWithKey(key)
+ if err != nil {
+ token = key
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancel()
+
+ req = req.WithContext(ctx)
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("User-Agent", "kling-sdk/1.0")
+
+ return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+ return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+ return "kling"
+}
+
+// ============================
+// helpers
+// ============================
+
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
+ r := &requestPayload{
+ Prompt: req.Prompt,
+ Image: req.Image,
+ Mode: defaultString(req.Mode, "std"),
+ Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
+ AspectRatio: a.getAspectRatio(req.Size),
+ Model: req.Model,
+ ModelName: req.Model,
+ CfgScale: 0.5,
+ }
+ if r.Model == "" {
+ r.Model = "kling-v1"
+ r.ModelName = "kling-v1"
+ }
+ return r
+}
+
+func (a *TaskAdaptor) getAspectRatio(size string) string {
+ switch size {
+ case "1024x1024", "512x512":
+ return "1:1"
+ case "1280x720", "1920x1080":
+ return "16:9"
+ case "720x1280", "1080x1920":
+ return "9:16"
+ default:
+ return "1:1"
+ }
+}
+
+func defaultString(s, def string) string {
+ if strings.TrimSpace(s) == "" {
+ return def
+ }
+ return s
+}
+
+func defaultInt(v int, def int) int {
+ if v == 0 {
+ return def
+ }
+ return v
+}
+
+// ============================
+// JWT helpers
+// ============================
+
+func (a *TaskAdaptor) createJWTToken() (string, error) {
+ return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
+}
+
+func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
+ parts := strings.Split(apiKey, ",")
+ if len(parts) != 2 {
+ return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
+ }
+ return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
+}
+
+func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
+ if accessKey == "" || secretKey == "" {
+ return "", fmt.Errorf("access key and secret key are required")
+ }
+ now := time.Now().Unix()
+ claims := jwt.MapClaims{
+ "iss": accessKey,
+ "exp": now + 1800, // 30 minutes
+ "nbf": now - 5,
+ }
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ token.Header["typ"] = "JWT"
+ return token.SignedString([]byte(secretKey))
+}
+
+// ParseResultUrl 提取视频任务结果的 url
+func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
+ data, ok := resp["data"].(map[string]any)
+ if !ok {
+ return "", fmt.Errorf("data field not found or invalid")
+ }
+ taskResult, ok := data["task_result"].(map[string]any)
+ if !ok {
+ return "", fmt.Errorf("task_result field not found or invalid")
+ }
+ videos, ok := taskResult["videos"].([]interface{})
+ if !ok || len(videos) == 0 {
+ return "", fmt.Errorf("videos field not found or empty")
+ }
+ video, ok := videos[0].(map[string]interface{})
+ if !ok {
+ return "", fmt.Errorf("video item invalid")
+ }
+ url, ok := video["url"].(string)
+ if !ok || url == "" {
+ return "", fmt.Errorf("url field not found or invalid")
+ }
+ return url, nil
+}
diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go
index 03d60516..f7042348 100644
--- a/relay/channel/task/suno/adaptor.go
+++ b/relay/channel/task/suno/adaptor.go
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
ChannelType int
}
+func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
+ return "", nil // todo implement this method if needed
+}
+
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
}
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index 44718a25..7ea3aae7 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = tencentHandler(c, resp)
}
diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go
index 5630650f..1446e06e 100644
--- a/relay/channel/tencent/relay-tencent.go
+++ b/relay/channel/tencent/relay-tencent.go
@@ -56,12 +56,11 @@ func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextRespon
},
}
if len(response.Choices) > 0 {
- content, _ := json.Marshal(response.Choices[0].Messages.Content)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index 7daf9a61..e568f651 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -12,6 +12,7 @@ import (
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
"one-api/setting/model_setting"
"strings"
@@ -31,6 +32,8 @@ var claudeModelMap = map[string]string{
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
+ "claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
+ "claude-opus-4-20250514": "claude-opus-4@20250514",
}
const anthropicVersion = "vertex-2023-10-16"
@@ -80,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
suffix := ""
if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- // suffix -thinking and -nothinking
- if strings.HasSuffix(info.OriginModelName, "-thinking") {
+ // 新增逻辑:处理 -thinking- 格式
+ if strings.Contains(info.UpstreamModelName, "-thinking-") {
+ parts := strings.Split(info.UpstreamModelName, "-thinking-")
+ info.UpstreamModelName = parts[0]
+ } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
- } else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
+ } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
@@ -93,14 +99,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} else {
suffix = "generateContent"
}
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
- region,
- adc.ProjectID,
- region,
- info.UpstreamModelName,
- suffix,
- ), nil
+ if region == "global" {
+ return fmt.Sprintf(
+ "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
+ adc.ProjectID,
+ info.UpstreamModelName,
+ suffix,
+ ), nil
+ } else {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+ region,
+ adc.ProjectID,
+ region,
+ info.UpstreamModelName,
+ suffix,
+ ), nil
+ }
} else if a.RequestMode == RequestModeClaude {
if info.IsStream {
suffix = "streamRawPredict?alt=sse"
@@ -111,14 +126,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
model = v
}
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
- region,
- adc.ProjectID,
- region,
- model,
- suffix,
- ), nil
+ if region == "global" {
+ return fmt.Sprintf(
+ "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
+ adc.ProjectID,
+ model,
+ suffix,
+ ), nil
+ } else {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
+ region,
+ adc.ProjectID,
+ region,
+ model,
+ suffix,
+ ), nil
+ }
} else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
@@ -190,7 +214,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case RequestModeClaude:
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
- err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
+ if info.RelayMode == constant.RelayModeGemini {
+ usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
+ } else {
+ err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
+ }
case RequestModeLlama:
err, usage = openai.OaiStreamHandler(c, resp, info)
}
@@ -199,7 +227,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case RequestModeClaude:
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
case RequestModeGemini:
- err, usage = gemini.GeminiChatHandler(c, resp, info)
+ if info.RelayMode == constant.RelayModeGemini {
+ usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
+ } else {
+ err, usage = gemini.GeminiChatHandler(c, resp, info)
+ }
case RequestModeLlama:
err, usage = openai.OpenaiHandler(c, resp, info)
}
diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go
index a4a48ee9..78233934 100644
--- a/relay/channel/volcengine/adaptor.go
+++ b/relay/channel/volcengine/adaptor.go
@@ -1,15 +1,19 @@
package volcengine
import (
+ "bytes"
"errors"
"fmt"
"io"
+ "mime/multipart"
"net/http"
+ "net/textproto"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+ "path/filepath"
"strings"
"github.com/gin-gonic/gin"
@@ -30,8 +34,146 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
+ switch info.RelayMode {
+ case constant.RelayModeImagesEdits:
+
+ var requestBody bytes.Buffer
+ writer := multipart.NewWriter(&requestBody)
+
+ writer.WriteField("model", request.Model)
+ // 获取所有表单字段
+ formData := c.Request.PostForm
+ // 遍历表单字段并打印输出
+ for key, values := range formData {
+ if key == "model" {
+ continue
+ }
+ for _, value := range values {
+ writer.WriteField(key, value)
+ }
+ }
+
+ // Parse the multipart form to handle both single image and multiple images
+ if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
+ return nil, errors.New("failed to parse multipart form")
+ }
+
+ if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
+ // Check if "image" field exists in any form, including array notation
+ var imageFiles []*multipart.FileHeader
+ var exists bool
+
+ // First check for standard "image" field
+ if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
+ // If not found, check for "image[]" field
+ if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
+ // If still not found, iterate through all fields to find any that start with "image["
+ foundArrayImages := false
+ for fieldName, files := range c.Request.MultipartForm.File {
+ if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
+ foundArrayImages = true
+ for _, file := range files {
+ imageFiles = append(imageFiles, file)
+ }
+ }
+ }
+
+ // If no image fields found at all
+ if !foundArrayImages && (len(imageFiles) == 0) {
+ return nil, errors.New("image is required")
+ }
+ }
+ }
+
+ // Process all image files
+ for i, fileHeader := range imageFiles {
+ file, err := fileHeader.Open()
+ if err != nil {
+ return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
+ }
+ defer file.Close()
+
+ // If multiple images, use image[] as the field name
+ fieldName := "image"
+ if len(imageFiles) > 1 {
+ fieldName = "image[]"
+ }
+
+ // Determine MIME type based on file extension
+ mimeType := detectImageMimeType(fileHeader.Filename)
+
+ // Create a form file with the appropriate content type
+ h := make(textproto.MIMEHeader)
+ h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
+ h.Set("Content-Type", mimeType)
+
+ part, err := writer.CreatePart(h)
+ if err != nil {
+ return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
+ }
+
+ if _, err := io.Copy(part, file); err != nil {
+ return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
+ }
+ }
+
+ // Handle mask file if present
+ if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
+ maskFile, err := maskFiles[0].Open()
+ if err != nil {
+ return nil, errors.New("failed to open mask file")
+ }
+ defer maskFile.Close()
+
+ // Determine MIME type for mask file
+ mimeType := detectImageMimeType(maskFiles[0].Filename)
+
+ // Create a form file with the appropriate content type
+ h := make(textproto.MIMEHeader)
+ h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
+ h.Set("Content-Type", mimeType)
+
+ maskPart, err := writer.CreatePart(h)
+ if err != nil {
+ return nil, errors.New("create form file failed for mask")
+ }
+
+ if _, err := io.Copy(maskPart, maskFile); err != nil {
+ return nil, errors.New("copy mask file failed")
+ }
+ }
+ } else {
+ return nil, errors.New("no multipart form data found")
+ }
+
+ // 关闭 multipart 编写器以设置分界线
+ writer.Close()
+ c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+ return bytes.NewReader(requestBody.Bytes()), nil
+
+ default:
+ return request, nil
+ }
+}
+
+// detectImageMimeType determines the MIME type based on the file extension
+func detectImageMimeType(filename string) string {
+ ext := strings.ToLower(filepath.Ext(filename))
+ switch ext {
+ case ".jpg", ".jpeg":
+ return "image/jpeg"
+ case ".png":
+ return "image/png"
+ case ".webp":
+ return "image/webp"
+ default:
+ // Try to detect from extension if possible
+ if strings.HasPrefix(ext, ".jp") {
+ return "image/jpeg"
+ }
+ // Default to png as a fallback
+ return "image/png"
+ }
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -46,6 +188,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
+ case constant.RelayModeImagesGenerations:
+ return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
@@ -91,6 +235,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info)
+ case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
+ err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
}
return
}
diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go
index e019c2dc..408160fb 100644
--- a/relay/channel/xai/text.go
+++ b/relay/channel/xai/text.go
@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
})
if !containStreamUsage {
- usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+ usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go
index 15d33510..c6ef722c 100644
--- a/relay/channel/xunfei/relay-xunfei.go
+++ b/relay/channel/xunfei/relay-xunfei.go
@@ -61,12 +61,11 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse
},
}
}
- content, _ := json.Marshal(response.Payload.Choices.Text[0].Content)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: response.Payload.Choices.Text[0].Content,
},
FinishReason: constant.FinishReasonStop,
}
diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go
index b0cac858..744538e3 100644
--- a/relay/channel/zhipu/relay-zhipu.go
+++ b/relay/channel/zhipu/relay-zhipu.go
@@ -108,12 +108,11 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
Usage: response.Data.Usage,
}
for i, choice := range response.Data.Choices {
- content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
openaiChoice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Role: choice.Role,
- Content: content,
+ Content: strings.Trim(choice.Content, "\""),
},
FinishReason: "",
}
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
index fb68a88a..42139ddf 100644
--- a/relay/claude_handler.go
+++ b/relay/claude_handler.go
@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
relayInfo.IsStream = true
}
- err = helper.ModelMappedHelper(c, relayInfo)
+ err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- textRequest.Model = relayInfo.UpstreamModelName
-
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
@@ -98,7 +96,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
// BudgetTokens 为 max_tokens 的 80%
textRequest.Thinking = &dto.Thinking{
Type: "enabled",
- BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
+ BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
@@ -126,7 +124,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
- return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
+ return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index f4fc3c1e..3759c363 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
}
const (
- RelayFormatOpenAI = "openai"
- RelayFormatClaude = "claude"
- RelayFormatGemini = "gemini"
+ RelayFormatOpenAI = "openai"
+ RelayFormatClaude = "claude"
+ RelayFormatGemini = "gemini"
+ RelayFormatOpenAIResponses = "openai_responses"
+ RelayFormatOpenAIAudio = "openai_audio"
+ RelayFormatOpenAIImage = "openai_image"
+ RelayFormatRerank = "rerank"
+ RelayFormatEmbedding = "embedding"
)
type RerankerInfo struct {
@@ -61,6 +66,7 @@ type RelayInfo struct {
TokenKey string
UserId int
Group string
+ UserGroup string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
@@ -142,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeRerank
+ info.RelayFormat = RelayFormatRerank
info.RerankerInfo = &RerankerInfo{
Documents: req.Documents,
ReturnDocuments: req.GetReturnDocuments(),
@@ -149,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
return info
}
+func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayFormat = RelayFormatOpenAIAudio
+ return info
+}
+
+func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayFormat = RelayFormatEmbedding
+ return info
+}
+
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeResponses
+ info.RelayFormat = RelayFormatOpenAIResponses
+
+ info.SupportStreamOptions = false
+
info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo),
}
@@ -174,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
return info
}
+func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayFormat = RelayFormatGemini
+ info.ShouldIncludeUsage = false
+ return info
+}
+
+func GenRelayInfoImage(c *gin.Context) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayFormat = RelayFormatOpenAIImage
+ return info
+}
+
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
@@ -204,6 +240,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
TokenKey: tokenKey,
UserId: userId,
Group: group,
+ UserGroup: c.GetString(constant.ContextKeyUserGroup),
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
@@ -241,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
- // responses 模式不支持 StreamOptions
- if relayconstant.RelayModeResponses == info.RelayMode {
- info.SupportStreamOptions = false
- }
return info
}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 4454e815..02a286e2 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -38,11 +38,16 @@ const (
RelayModeSunoFetchByID
RelayModeSunoSubmit
+ RelayModeKlingFetchByID
+ RelayModeKlingSubmit
+
RelayModeRerank
RelayModeResponses
RelayModeRealtime
+
+ RelayModeGemini
)
func Path2RelayMode(path string) int {
@@ -75,6 +80,8 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeRerank
} else if strings.HasPrefix(path, "/v1/realtime") {
relayMode = RelayModeRealtime
+ } else if strings.HasPrefix(path, "/v1beta/models") {
+ relayMode = RelayModeGemini
}
return relayMode
}
@@ -129,3 +136,13 @@ func Path2RelaySuno(method, path string) int {
}
return relayMode
}
+
+func Path2RelayKling(method, path string) int {
+ relayMode := RelayModeUnknown
+ if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
+ relayMode = RelayModeKlingSubmit
+ } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
+ relayMode = RelayModeKlingFetchByID
+ }
+ return relayMode
+}
diff --git a/relay/relay_embedding.go b/relay/embedding_handler.go
similarity index 94%
rename from relay/relay_embedding.go
rename to relay/embedding_handler.go
index b4909849..849c70da 100644
--- a/relay/relay_embedding.go
+++ b/relay/embedding_handler.go
@@ -15,7 +15,7 @@ import (
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
- token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
+ token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
return token
}
@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
}
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
- relayInfo := relaycommon.GenRelayInfo(c)
+ relayInfo := relaycommon.GenRelayInfoEmbedding(c)
var embeddingRequest *dto.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
}
- err = helper.ModelMappedHelper(c, relayInfo)
+ err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- embeddingRequest.Model = relayInfo.UpstreamModelName
-
promptToken := getEmbeddingPromptToken(*embeddingRequest)
relayInfo.PromptTokens = promptToken
diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go
new file mode 100644
index 00000000..14d58cc5
--- /dev/null
+++ b/relay/gemini_handler.go
@@ -0,0 +1,190 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/channel/gemini"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
+ request := &gemini.GeminiChatRequest{}
+ err := common.UnmarshalBodyReusable(c, request)
+ if err != nil {
+ return nil, err
+ }
+ if len(request.Contents) == 0 {
+ return nil, errors.New("contents is required")
+ }
+ return request, nil
+}
+
+// 流模式
+// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
+func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+ if c.Query("alt") == "sse" {
+ relayInfo.IsStream = true
+ }
+
+ // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
+ // relayInfo.IsStream = true
+ // }
+}
+
+func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
+ var inputTexts []string
+ for _, content := range textRequest.Contents {
+ for _, part := range content.Parts {
+ if part.Text != "" {
+ inputTexts = append(inputTexts, part.Text)
+ }
+ }
+ }
+ if len(inputTexts) == 0 {
+ return nil, nil
+ }
+
+ sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
+ return sensitiveWords, err
+}
+
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
+ // 计算输入 token 数量
+ var inputTexts []string
+ for _, content := range req.Contents {
+ for _, part := range content.Parts {
+ if part.Text != "" {
+ inputTexts = append(inputTexts, part.Text)
+ }
+ }
+ }
+
+ inputText := strings.Join(inputTexts, "\n")
+ inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
+ info.PromptTokens = inputTokens
+ return inputTokens
+}
+
+func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
+ req, err := getAndValidateGeminiRequest(c)
+ if err != nil {
+ common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
+ return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
+ }
+
+ relayInfo := relaycommon.GenRelayInfoGemini(c)
+
+ // 检查 Gemini 流式模式
+ checkGeminiStreamMode(c, relayInfo)
+
+ if setting.ShouldCheckPromptSensitive() {
+ sensitiveWords, err := checkGeminiInputSensitive(req)
+ if err != nil {
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
+ return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
+ }
+ }
+
+ // model mapped 模型映射
+ err = helper.ModelMappedHelper(c, relayInfo, req)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
+ }
+
+ if value, exists := c.Get("prompt_tokens"); exists {
+ promptTokens := value.(int)
+ relayInfo.SetPromptTokens(promptTokens)
+ } else {
+ promptTokens := getGeminiInputTokens(req, relayInfo)
+ // NOTE(review): removed a stale `if err != nil` guard that stood here —
+ // `err` was last assigned by ModelMappedHelper and already handled above,
+ // and getGeminiInputTokens cannot fail, so that branch was dead code.
+ c.Set("prompt_tokens", promptTokens)
+ }
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
+ }
+
+ // pre consume quota
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if openaiErr != nil {
+ return openaiErr
+ }
+ defer func() {
+ if openaiErr != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+ }
+
+ adaptor.Init(relayInfo)
+
+ // Clean up empty system instruction
+ if req.SystemInstructions != nil {
+ hasContent := false
+ for _, part := range req.SystemInstructions.Parts {
+ if part.Text != "" {
+ hasContent = true
+ break
+ }
+ }
+ if !hasContent {
+ req.SystemInstructions = nil
+ }
+ }
+
+ requestBody, err := json.Marshal(req)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
+ }
+
+ if common.DebugEnabled {
+ fmt.Printf("Gemini request body: %s\n", string(requestBody))
+ }
+
+ resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
+ if err != nil {
+ common.LogError(c, "Do gemini request failed: "+err.Error())
+ return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ if httpResp.StatusCode != http.StatusOK {
+ openaiErr = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+ return openaiErr
+ }
+ }
+
+ usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+ if openaiErr != nil {
+ service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+ return openaiErr
+ }
+
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ return nil
+}
diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go
index 9bf67c03..c1735149 100644
--- a/relay/helper/model_mapped.go
+++ b/relay/helper/model_mapped.go
@@ -4,12 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
+ common2 "one-api/common"
+ "one-api/dto"
"one-api/relay/common"
"github.com/gin-gonic/gin"
)
-func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
+func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" {
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
info.UpstreamModelName = currentModel
}
}
+ if request != nil {
+ switch info.RelayFormat {
+ case common.RelayFormatGemini:
+ // Gemini 模型映射
+ case common.RelayFormatClaude:
+ if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
+ claudeRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatOpenAIResponses:
+ if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
+ openAIResponsesRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatOpenAIAudio:
+ if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
+ openAIAudioRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatOpenAIImage:
+ if imageRequest, ok := request.(*dto.ImageRequest); ok {
+ imageRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatRerank:
+ if rerankRequest, ok := request.(*dto.RerankRequest); ok {
+ rerankRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatEmbedding:
+ if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
+ embeddingRequest.Model = info.UpstreamModelName
+ }
+ default:
+ if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
+ openAIRequest.Model = info.UpstreamModelName
+ } else {
+ common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
+ }
+ }
+ }
return nil
}
diff --git a/relay/helper/price.go b/relay/helper/price.go
index 89efa1da..1ee2767e 100644
--- a/relay/helper/price.go
+++ b/relay/helper/price.go
@@ -2,14 +2,19 @@ package helper
import (
"fmt"
- "github.com/gin-gonic/gin"
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
- "one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
)
+type GroupRatioInfo struct {
+ GroupRatio float64
+ GroupSpecialRatio float64
+}
+
type PriceData struct {
ModelPrice float64
ModelRatio float64
@@ -17,18 +22,50 @@ type PriceData struct {
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
- GroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
+ GroupRatioInfo GroupRatioInfo
}
func (p PriceData) ToSetting() string {
- return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+ return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+}
+
+// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
+func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
+ groupRatioInfo := GroupRatioInfo{
+ GroupRatio: 1.0, // default ratio
+ GroupSpecialRatio: -1,
+ }
+
+ // check auto group
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ if common.DebugEnabled {
+ println(fmt.Sprintf("final group: %s", autoGroup))
+ }
+ relayInfo.Group = autoGroup.(string)
+ }
+
+ // check user group special ratio
+ userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
+ if ok {
+ // user group special ratio
+ groupRatioInfo.GroupSpecialRatio = userGroupRatio
+ groupRatioInfo.GroupRatio = userGroupRatio
+ } else {
+ // normal group ratio
+ groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.Group)
+ }
+
+ return groupRatioInfo
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
- modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
- groupRatio := setting.GetGroupRatio(info.Group)
+ modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
+
+ groupRatioInfo := HandleGroupRatio(c, info)
+
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
@@ -41,7 +78,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
preConsumedTokens = promptTokens + maxTokens
}
var success bool
- modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
+ modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
if !success {
acceptUnsetRatio := false
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
@@ -54,21 +91,21 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
}
}
- completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
- cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
- cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
- imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
- ratio := modelRatio * groupRatio
+ completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
+ cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
+ cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
+ imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
+ ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+ preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
priceData := PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
- GroupRatio: groupRatio,
+ GroupRatioInfo: groupRatioInfo,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
@@ -84,11 +121,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
}
func ContainPriceOrRatio(modelName string) bool {
- _, ok := operation_setting.GetModelPrice(modelName, false)
+ _, ok := ratio_setting.GetModelPrice(modelName, false)
if ok {
return true
}
- _, ok = operation_setting.GetModelRatio(modelName)
+ _, ok = ratio_setting.GetModelRatio(modelName)
if ok {
return true
}
diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go
index c1bc0d6e..a69877e2 100644
--- a/relay/helper/stream_scanner.go
+++ b/relay/helper/stream_scanner.go
@@ -3,6 +3,7 @@ package helper
import (
"bufio"
"context"
+ "fmt"
"io"
"net/http"
"one-api/common"
@@ -19,8 +20,8 @@ import (
)
const (
- InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
- MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
+ InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
+ MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
DefaultPingInterval = 10 * time.Second
)
@@ -30,7 +31,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
return
}
- defer resp.Body.Close()
+ // 确保响应体总是被关闭
+ defer func() {
+ if resp.Body != nil {
+ resp.Body.Close()
+ }
+ }()
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if strings.HasPrefix(info.UpstreamModelName, "o") {
@@ -39,11 +45,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
}
var (
- stopChan = make(chan bool, 2)
+ stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker
writeMutex sync.Mutex // Mutex to protect concurrent writes
+ wg sync.WaitGroup // 用于等待所有 goroutine 退出
)
generalSettings := operation_setting.GetGeneralSetting()
@@ -57,13 +64,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
pingTicker = time.NewTicker(pingInterval)
}
+ // 改进资源清理,确保所有 goroutine 正确退出
defer func() {
+ // 通知所有 goroutine 停止
+ common.SafeSendBool(stopChan, true)
+
ticker.Stop()
if pingTicker != nil {
pingTicker.Stop()
}
+
+ // 等待所有 goroutine 退出,最多等待5秒
+ done := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
+ common.LogError(c, "timeout waiting for goroutines to exit")
+ }
+
close(stopChan)
}()
+
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c)
@@ -73,35 +99,95 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
ctx = context.WithValue(ctx, "stop_chan", stopChan)
- // Handle ping data sending
+ // Handle ping data sending with improved error handling
if pingEnabled && pingTicker != nil {
+ wg.Add(1)
gopool.Go(func() {
+ defer func() {
+ wg.Done()
+ if r := recover(); r != nil {
+ common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
+ common.SafeSendBool(stopChan, true)
+ }
+ if common.DebugEnabled {
+ println("ping goroutine exited")
+ }
+ }()
+
+ // 添加超时保护,防止 goroutine 无限运行
+ maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
+ pingTimeout := time.NewTimer(maxPingDuration)
+ defer pingTimeout.Stop()
+
for {
select {
case <-pingTicker.C:
- writeMutex.Lock() // Lock before writing
- err := PingData(c)
- writeMutex.Unlock() // Unlock after writing
- if err != nil {
- common.LogError(c, "ping data error: "+err.Error())
- common.SafeSendBool(stopChan, true)
+ // 使用超时机制防止写操作阻塞
+ done := make(chan error, 1)
+ go func() {
+ writeMutex.Lock()
+ defer writeMutex.Unlock()
+ done <- PingData(c)
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ common.LogError(c, "ping data error: "+err.Error())
+ return
+ }
+ if common.DebugEnabled {
+ println("ping data sent")
+ }
+ case <-time.After(10 * time.Second):
+ common.LogError(c, "ping data send timeout")
+ return
+ case <-ctx.Done():
+ return
+ case <-stopChan:
return
}
- if common.DebugEnabled {
- println("ping data sent")
- }
case <-ctx.Done():
- if common.DebugEnabled {
- println("ping data goroutine stopped")
- }
+ return
+ case <-stopChan:
+ return
+ case <-c.Request.Context().Done():
+ // 监听客户端断开连接
+ return
+ case <-pingTimeout.C:
+ common.LogError(c, "ping goroutine max duration reached")
return
}
}
})
}
+ // Scanner goroutine with improved error handling
+ wg.Add(1)
common.RelayCtxGo(ctx, func() {
+ defer func() {
+ wg.Done()
+ if r := recover(); r != nil {
+ common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
+ }
+ common.SafeSendBool(stopChan, true)
+ if common.DebugEnabled {
+ println("scanner goroutine exited")
+ }
+ }()
+
for scanner.Scan() {
+ // 检查是否需要停止
+ select {
+ case <-stopChan:
+ return
+ case <-ctx.Done():
+ return
+ case <-c.Request.Context().Done():
+ return
+ default:
+ }
+
ticker.Reset(streamingTimeout)
data := scanner.Text()
if common.DebugEnabled {
@@ -119,11 +205,27 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
data = strings.TrimSuffix(data, "\r")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
- writeMutex.Lock() // Lock before writing
- success := dataHandler(data)
- writeMutex.Unlock() // Unlock after writing
- if !success {
- break
+
+ // 使用超时机制防止写操作阻塞
+ done := make(chan bool, 1)
+ go func() {
+ writeMutex.Lock()
+ defer writeMutex.Unlock()
+ done <- dataHandler(data)
+ }()
+
+ select {
+ case success := <-done:
+ if !success {
+ return
+ }
+ case <-time.After(10 * time.Second):
+ common.LogError(c, "data handler timeout")
+ return
+ case <-ctx.Done():
+ return
+ case <-stopChan:
+ return
}
}
}
@@ -133,17 +235,18 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
common.LogError(c, "scanner error: "+err.Error())
}
}
-
- common.SafeSendBool(stopChan, true)
})
+ // 主循环等待完成或超时
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
- common.SafeSendBool(stopChan, true)
case <-stopChan:
// 正常结束
common.LogInfo(c, "streaming finished")
+ case <-c.Request.Context().Done():
+ // 客户端断开连接
+ common.LogInfo(c, "client disconnected")
}
}
diff --git a/relay/relay-image.go b/relay/image_handler.go
similarity index 71%
rename from relay/relay-image.go
rename to relay/image_handler.go
index daed3d80..15a42e79 100644
--- a/relay/relay-image.go
+++ b/relay/image_handler.go
@@ -17,6 +17,8 @@ import (
"one-api/setting"
"strings"
+ "one-api/relay/constant"
+
"github.com/gin-gonic/gin"
)
@@ -41,16 +43,36 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
imageRequest.Quality = "standard"
}
}
+ if imageRequest.N == 0 {
+ imageRequest.N = 1
+ }
+
+ if info.ApiType == constant.APITypeVolcEngine {
+ watermark := formData.Has("watermark")
+ imageRequest.Watermark = &watermark
+ }
default:
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
}
+
+ if imageRequest.Model == "" {
+ imageRequest.Model = "dall-e-3"
+ }
+
+ if strings.Contains(imageRequest.Size, "×") {
+ return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
+ }
+
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
}
+ if imageRequest.Size == "" {
+ imageRequest.Size = "1024x1024"
+ }
} else if imageRequest.Model == "dall-e-3" {
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
@@ -58,74 +80,24 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
- // N should between 1 and 10
- //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
- // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
- //}
+ if imageRequest.Size == "" {
+ imageRequest.Size = "1024x1024"
+ }
+ } else if imageRequest.Model == "gpt-image-1" {
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "auto"
+ }
+ }
+
+ if imageRequest.Prompt == "" {
+ return nil, errors.New("prompt is required")
+ }
+
+ if imageRequest.N == 0 {
+ imageRequest.N = 1
}
}
- if imageRequest.Prompt == "" {
- return nil, errors.New("prompt is required")
- }
-
- if imageRequest.Model == "" {
- imageRequest.Model = "dall-e-2"
- }
- if strings.Contains(imageRequest.Size, "×") {
- return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
- }
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
-
- err := common.UnmarshalBodyReusable(c, imageRequest)
- if err != nil {
- return nil, err
- }
- if imageRequest.Prompt == "" {
- return nil, errors.New("prompt is required")
- }
- if strings.Contains(imageRequest.Size, "×") {
- return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
- }
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
- if imageRequest.Model == "" {
- imageRequest.Model = "dall-e-2"
- }
- // x.ai grok-2-image not support size, quality or style
- if imageRequest.Size == "empty" {
- imageRequest.Size = ""
- }
-
- // Not "256x256", "512x512", or "1024x1024"
- if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
- if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
- return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
- }
- } else if imageRequest.Model == "dall-e-3" {
- if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
- return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
- }
- if imageRequest.Quality == "" {
- imageRequest.Quality = "standard"
- }
- //if imageRequest.N != 1 {
- // return nil, errors.New("n must be 1")
- //}
- }
- // N should between 1 and 10
- //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
- // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
- //}
if setting.ShouldCheckPromptSensitive() {
words, err := service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil {
@@ -137,7 +109,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
}
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
- relayInfo := relaycommon.GenRelayInfo(c)
+ relayInfo := relaycommon.GenRelayInfoImage(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil {
@@ -145,13 +117,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
}
- err = helper.ModelMappedHelper(c, relayInfo)
+ err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- imageRequest.Model = relayInfo.UpstreamModelName
-
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
@@ -197,7 +167,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// reset model price
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
- quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
+ quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -229,6 +199,10 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
requestBody = bytes.NewBuffer(jsonData)
}
+ if common.DebugEnabled {
+ println(fmt.Sprintf("image request body: %s", requestBody))
+ }
+
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 9d0a2077..ce4346b6 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -15,7 +15,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -174,17 +174,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
- modelPrice, success := operation_setting.GetModelPrice(modelName, true)
+ modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
- defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
+ defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
- groupRatio := setting.GetGroupRatio(group)
+ groupRatio := ratio_setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {
@@ -480,17 +480,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
modelName := service.CoverActionToModelName(midjRequest.Action)
- modelPrice, success := operation_setting.GetModelPrice(modelName, true)
+ modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
- defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
+ defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
- groupRatio := setting.GetGroupRatio(group)
+ groupRatio := ratio_setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 8d5cd384..db8d0d3b 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -47,6 +47,20 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
if textRequest.Model == "" {
return nil, errors.New("model is required")
}
+ if textRequest.WebSearchOptions != nil {
+ if textRequest.WebSearchOptions.SearchContextSize != "" {
+ validSizes := map[string]bool{
+ "high": true,
+ "medium": true,
+ "low": true,
+ }
+ if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
+ return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
+ }
+ } else {
+ textRequest.WebSearchOptions.SearchContextSize = "medium"
+ }
+ }
switch relayInfo.RelayMode {
case relayconstant.RelayModeCompletions:
if textRequest.Prompt == "" {
@@ -76,11 +90,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateTextRequest(c, relayInfo)
+
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
+ if textRequest.WebSearchOptions != nil {
+ c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
+ }
+
if setting.ShouldCheckPromptSensitive() {
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
@@ -89,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}
- err = helper.ModelMappedHelper(c, relayInfo)
+ err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- textRequest.Model = relayInfo.UpstreamModelName
-
// 获取 promptTokens,如果上下文中已经存在,则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
@@ -234,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
case relayconstant.RelayModeCompletions:
- promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
+ promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
case relayconstant.RelayModeModerations:
- promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+ promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
case relayconstant.RelayModeEmbeddings:
- promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+ promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
default:
err = errors.New("unknown relay mode")
promptTokens = 0
@@ -334,6 +351,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
promptTokens := usage.PromptTokens
cacheTokens := usage.PromptTokensDetails.CachedTokens
imageTokens := usage.PromptTokensDetails.ImageTokens
+ audioTokens := usage.PromptTokensDetails.AudioTokens
completionTokens := usage.CompletionTokens
modelName := relayInfo.OriginModelName
@@ -342,13 +360,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
cacheRatio := priceData.CacheRatio
imageRatio := priceData.ImageRatio
modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
dImageTokens := decimal.NewFromInt(int64(imageTokens))
+ dAudioTokens := decimal.NewFromInt(int64(audioTokens))
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
dCompletionRatio := decimal.NewFromFloat(completionRatio)
dCacheRatio := decimal.NewFromFloat(cacheRatio)
@@ -370,9 +389,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 $%s",
+ extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
}
+ } else if strings.HasSuffix(modelName, "search-preview") {
+ // search-preview 模型不支持 response api
+ searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
+ if searchContextSize == "" {
+ searchContextSize = "medium"
+ }
+ webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
+ dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
+ searchContextSize, dWebSearchQuota.String())
}
// file search tool 计费
var dFileSearchQuota decimal.Decimal
@@ -383,23 +413,43 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 $%s",
+ extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
fileSearchTool.CallCount, dFileSearchQuota.String())
}
}
var quotaCalculateDecimal decimal.Decimal
- if !priceData.UsePrice {
- nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
- cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio)
- promptQuota := nonCachedTokens.Add(cachedTokensWithRatio)
- if imageTokens > 0 {
- nonImageTokens := dPromptTokens.Sub(dImageTokens)
- imageTokensWithRatio := dImageTokens.Mul(dImageRatio)
- promptQuota = nonImageTokens.Add(imageTokensWithRatio)
+ var audioInputQuota decimal.Decimal
+ var audioInputPrice float64
+ if !priceData.UsePrice {
+ baseTokens := dPromptTokens
+ // 减去 cached tokens
+ var cachedTokensWithRatio decimal.Decimal
+ if !dCacheTokens.IsZero() {
+ baseTokens = baseTokens.Sub(dCacheTokens)
+ cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
}
+ // 减去 image tokens
+ var imageTokensWithRatio decimal.Decimal
+ if !dImageTokens.IsZero() {
+ baseTokens = baseTokens.Sub(dImageTokens)
+ imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
+ }
+
+ // 减去 Gemini audio tokens
+ if !dAudioTokens.IsZero() {
+ audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
+ if audioInputPrice > 0 {
+ // 重新计算 base tokens
+ baseTokens = baseTokens.Sub(dAudioTokens)
+ audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
+ }
+ }
+ promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
+
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
@@ -413,6 +463,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// 添加 responses tools call 调用的配额
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
+ // 添加 audio input 独立计费
+ quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens
@@ -457,16 +509,22 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
- other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
+ other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio
other["image_output"] = imageTokens
}
- if !dWebSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
- if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
+ if !dWebSearchQuota.IsZero() {
+ if relayInfo.ResponsesUsageInfo != nil {
+ if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
+ other["web_search"] = true
+ other["web_search_call_count"] = webSearchTool.CallCount
+ other["web_search_price"] = webSearchPrice
+ }
+ } else if strings.HasSuffix(modelName, "search-preview") {
other["web_search"] = true
- other["web_search_call_count"] = webSearchTool.CallCount
+ other["web_search_call_count"] = 1
other["web_search_price"] = webSearchPrice
}
}
@@ -477,6 +535,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other["file_search_price"] = fileSearchPrice
}
}
+ if !audioInputQuota.IsZero() {
+ other["audio_input_seperate_price"] = true
+ other["audio_input_token_count"] = audioTokens
+ other["audio_input_price"] = audioInputPrice
+ }
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 7bf0da9f..626bb7e4 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -22,6 +22,7 @@ import (
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
+ "one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
@@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
// return &aiproxy.Adaptor{}
case commonconstant.TaskPlatformSuno:
return &suno.TaskAdaptor{}
+ case commonconstant.TaskPlatformKling:
+ return &kling.TaskAdaptor{}
}
return nil
}
diff --git a/relay/relay_task.go b/relay/relay_task.go
index 26874ba6..245fd681 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -15,8 +15,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
- "one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
)
/*
@@ -38,9 +37,12 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
- modelPrice, success := operation_setting.GetModelPrice(modelName, true)
+ if platform == constant.TaskPlatformKling {
+ modelName = relayInfo.OriginModelName
+ }
+ modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
if !success {
- defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
+ defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -49,7 +51,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
// 预扣
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
+ groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
@@ -137,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
relayInfo.ConsumeQuota = true
// insert task
- task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
+ task := model.InitTask(platform, relayInfo)
task.TaskID = taskID
task.Quota = quota
task.Data = taskData
+ task.Action = relayInfo.Action
err = task.Insert()
if err != nil {
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
@@ -150,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
- relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
- relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
+ relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
+ relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
+ relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
}
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
@@ -226,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
return
}
+func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+ taskId := c.Param("id")
+ userId := c.GetInt("id")
+
+ originTask, exist, err := model.GetByTaskId(userId, taskId)
+ if err != nil {
+ taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
+ return
+ }
+ if !exist {
+ taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
+ return
+ }
+
+ respBody, err = json.Marshal(dto.TaskResponse[any]{
+ Code: "success",
+ Data: TaskModel2Dto(originTask),
+ })
+ return
+}
+
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{
TaskID: task.TaskID,
diff --git a/relay/relay_rerank.go b/relay/rerank_handler.go
similarity index 92%
rename from relay/relay_rerank.go
rename to relay/rerank_handler.go
index 6ca98de7..319811b8 100644
--- a/relay/relay_rerank.go
+++ b/relay/rerank_handler.go
@@ -14,12 +14,10 @@ import (
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
- token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+ token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
for _, document := range rerankRequest.Documents {
- tkm, err := service.CountTokenInput(document, rerankRequest.Model)
- if err == nil {
- token += tkm
- }
+ tkm := service.CountTokenInput(document, rerankRequest.Model)
+ token += tkm
}
return token
}
@@ -42,13 +40,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
- err = helper.ModelMappedHelper(c, relayInfo)
+ err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- rerankRequest.Model = relayInfo.UpstreamModelName
-
promptToken := getRerankPromptToken(*rerankRequest)
relayInfo.PromptTokens = promptToken
diff --git a/relay/relay-responses.go b/relay/responses_handler.go
similarity index 93%
rename from relay/relay-responses.go
rename to relay/responses_handler.go
index fd3ddb5a..e744e354 100644
--- a/relay/relay-responses.go
+++ b/relay/responses_handler.go
@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
return sensitiveWords, err
}
-func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
- inputTokens, err := service.CountTokenInput(req.Input, req.Model)
+func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
+ inputTokens := service.CountTokenInput(req.Input, req.Model)
info.PromptTokens = inputTokens
- return inputTokens, err
+ return inputTokens
}
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -63,19 +63,16 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
}
}
- err = helper.ModelMappedHelper(c, relayInfo)
+ err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
}
- req.Model = relayInfo.UpstreamModelName
+
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
- promptTokens, err := getInputTokens(req, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
- }
+ promptTokens := getInputTokens(req, relayInfo)
c.Set("prompt_tokens", promptTokens)
}
diff --git a/relay/websocket.go b/relay/websocket.go
index c815eb71..571f3a82 100644
--- a/relay/websocket.go
+++ b/relay/websocket.go
@@ -6,12 +6,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
- "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
"one-api/service"
- "one-api/setting"
- "one-api/setting/operation_setting"
)
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
//isModelMapped = true
}
}
- //relayInfo.UpstreamModelName = textRequest.Model
- modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
- //err := service.SensitiveWordsCheck(textRequest)
-
- //if constant.ShouldCheckPromptSensitive() {
- // err = checkRequestSensitive(textRequest, relayInfo)
- // if err != nil {
- // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
- // }
- //}
-
- //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
- //// count messages token error 计算promptTokens错误
- //if err != nil {
- // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
- //}
- //
- if !getModelPriceSuccess {
- preConsumedTokens := common.PreConsumedQuota
- //if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
- // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
- //}
- modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- relayInfo.UsePrice = true
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
return openaiErr
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
- userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+ userQuota, priceData, "")
return nil
}
diff --git a/router/api-router.go b/router/api-router.go
index 1720ff57..badfa7bf 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -16,6 +16,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/setup", controller.GetSetup)
apiRouter.POST("/setup", controller.PostSetup)
apiRouter.GET("/status", controller.GetStatus)
+ apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
apiRouter.GET("/notice", controller.GetNotice)
@@ -35,6 +36,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
+ apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
userRoute := apiRouter.Group("/user")
{
@@ -80,6 +82,13 @@ func SetApiRouter(router *gin.Engine) {
optionRoute.GET("/", controller.GetOptions)
optionRoute.PUT("/", controller.UpdateOption)
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
+ optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
+ }
+ ratioSyncRoute := apiRouter.Group("/ratio_sync")
+ ratioSyncRoute.Use(middleware.RootAuth())
+ {
+ ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
+ ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
}
channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth())
@@ -105,6 +114,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
channelRoute.POST("/fetch_models", controller.FetchModels)
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
+ channelRoute.GET("/tag/models", controller.GetTagModels)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())
@@ -124,6 +134,7 @@ func SetApiRouter(router *gin.Engine) {
redemptionRoute.GET("/:id", controller.GetRedemption)
redemptionRoute.POST("/", controller.AddRedemption)
redemptionRoute.PUT("/", controller.UpdateRedemption)
+ redemptionRoute.DELETE("/invalid", controller.DeleteInvalidRedemption)
redemptionRoute.DELETE("/:id", controller.DeleteRedemption)
}
logRoute := apiRouter.Group("/log")
diff --git a/router/main.go b/router/main.go
index b8ac4055..0d2bfdce 100644
--- a/router/main.go
+++ b/router/main.go
@@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetApiRouter(router)
SetDashboardRouter(router)
SetRelayRouter(router)
+ SetVideoRouter(router)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if common.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
diff --git a/router/relay-router.go b/router/relay-router.go
index 4cd84b41..aa7f27a8 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -11,6 +11,7 @@ import (
func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.CORS())
router.Use(middleware.DecompressRequestMiddleware())
+ router.Use(middleware.StatsMiddleware())
// https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth())
@@ -79,6 +80,14 @@ func SetRelayRouter(router *gin.Engine) {
relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
}
+ relayGeminiRouter := router.Group("/v1beta")
+ relayGeminiRouter.Use(middleware.TokenAuth())
+ relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
+ relayGeminiRouter.Use(middleware.Distribute())
+ {
+ // Gemini API 路径格式: /v1beta/models/{model_name}:{action}
+ relayGeminiRouter.POST("/models/*path", controller.Relay)
+ }
}
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
diff --git a/router/video-router.go b/router/video-router.go
new file mode 100644
index 00000000..7201c34a
--- /dev/null
+++ b/router/video-router.go
@@ -0,0 +1,17 @@
+package router
+
+import (
+ "one-api/controller"
+ "one-api/middleware"
+
+ "github.com/gin-gonic/gin"
+)
+
+func SetVideoRouter(router *gin.Engine) {
+ videoV1Router := router.Group("/v1")
+ videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
+ {
+ videoV1Router.POST("/video/generations", controller.RelayTask)
+ videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
+ }
+}
diff --git a/service/audio.go b/service/audio.go
index d558e96f..c4b6f01b 100644
--- a/service/audio.go
+++ b/service/audio.go
@@ -3,6 +3,7 @@ package service
import (
"encoding/base64"
"fmt"
+ "strings"
)
func parseAudio(audioBase64 string, format string) (duration float64, err error) {
@@ -29,3 +30,19 @@ func parseAudio(audioBase64 string, format string) (duration float64, err error)
duration = float64(samplesCount) / float64(sampleRate)
return duration, nil
}
+
+func DecodeBase64AudioData(audioBase64 string) (string, error) {
+ // 检查并移除 data:audio/xxx;base64, 前缀
+ idx := strings.Index(audioBase64, ",")
+ if idx != -1 {
+ audioBase64 = audioBase64[idx+1:]
+ }
+
+ // 解码 Base64 数据
+ _, err := base64.StdEncoding.DecodeString(audioBase64)
+ if err != nil {
+ return "", fmt.Errorf("base64 decode error: %v", err)
+ }
+
+ return audioBase64, nil
+}
diff --git a/service/channel.go b/service/channel.go
index e3a76af4..746e9a34 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -59,6 +59,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
return true
case "billing_not_active":
return true
+ case "pre_consume_token_quota_failed":
+ return true
}
switch err.Error.Type {
case "insufficient_quota":
diff --git a/service/convert.go b/service/convert.go
index cc462b40..7a9e8403 100644
--- a/service/convert.go
+++ b/service/convert.go
@@ -5,6 +5,7 @@ import (
"fmt"
"one-api/common"
"one-api/dto"
+ "one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
"strings"
)
@@ -18,10 +19,24 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
Stream: claudeRequest.Stream,
}
- if claudeRequest.Thinking != nil {
- if strings.HasSuffix(info.OriginModelName, "-thinking") &&
- !strings.HasSuffix(claudeRequest.Model, "-thinking") {
- openAIRequest.Model = openAIRequest.Model + "-thinking"
+ isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
+
+ if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
+ if isOpenRouter {
+ reasoning := openrouter.RequestReasoning{
+ MaxTokens: claudeRequest.Thinking.GetBudgetTokens(),
+ }
+ reasoningJSON, err := json.Marshal(reasoning)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal reasoning: %w", err)
+ }
+ openAIRequest.Reasoning = reasoningJSON
+ } else {
+ thinkingSuffix := "-thinking"
+ if strings.HasSuffix(info.OriginModelName, thinkingSuffix) &&
+ !strings.HasSuffix(openAIRequest.Model, thinkingSuffix) {
+ openAIRequest.Model = openAIRequest.Model + thinkingSuffix
+ }
}
}
@@ -62,16 +77,30 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
} else {
systems := claudeRequest.ParseSystem()
if len(systems) > 0 {
- systemStr := ""
openAIMessage := dto.Message{
Role: "system",
}
- for _, system := range systems {
- if system.Text != nil {
- systemStr += *system.Text
+ isOpenRouterClaude := isOpenRouter && strings.HasPrefix(info.UpstreamModelName, "anthropic/claude")
+ if isOpenRouterClaude {
+ systemMediaMessages := make([]dto.MediaContent, 0, len(systems))
+ for _, system := range systems {
+ message := dto.MediaContent{
+ Type: "text",
+ Text: system.GetText(),
+ CacheControl: system.CacheControl,
+ }
+ systemMediaMessages = append(systemMediaMessages, message)
}
+ openAIMessage.SetMediaContent(systemMediaMessages)
+ } else {
+ systemStr := ""
+ for _, system := range systems {
+ if system.Text != nil {
+ systemStr += *system.Text
+ }
+ }
+ openAIMessage.SetStringContent(systemStr)
}
- openAIMessage.SetStringContent(systemStr)
openAIMessages = append(openAIMessages, openAIMessage)
}
}
@@ -97,8 +126,9 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
switch mediaMsg.Type {
case "text":
message := dto.MediaContent{
- Type: "text",
- Text: mediaMsg.GetText(),
+ Type: "text",
+ Text: mediaMsg.GetText(),
+ CacheControl: mediaMsg.CacheControl,
}
mediaMessages = append(mediaMessages, message)
case "image":
diff --git a/service/error.go b/service/error.go
index 1bf5992b..f3d8a17d 100644
--- a/service/error.go
+++ b/service/error.go
@@ -29,9 +29,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
- if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
- text = "请求上游地址失败"
+ if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
+ if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+ common.SysLog(fmt.Sprintf("error: %s", text))
+ text = "请求上游地址失败"
+ }
}
openAIError := dto.OpenAIError{
Message: text,
@@ -53,9 +55,11 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
- if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
- text = "请求上游地址失败"
+ if !strings.HasPrefix(lowerText, "get file base64 from url") {
+ if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+ common.SysLog(fmt.Sprintf("error: %s", text))
+ text = "请求上游地址失败"
+ }
}
claudeError := dto.ClaudeError{
Message: text,
diff --git a/service/file_decoder.go b/service/file_decoder.go
index bbb188f8..c1d4fb0c 100644
--- a/service/file_decoder.go
+++ b/service/file_decoder.go
@@ -4,8 +4,10 @@ import (
"encoding/base64"
"fmt"
"io"
+ "one-api/common"
"one-api/constant"
"one-api/dto"
+ "strings"
)
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
@@ -30,9 +32,104 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
// Convert to base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
+ mimeType := resp.Header.Get("Content-Type")
+ if len(strings.Split(mimeType, ";")) > 1 {
+ // If Content-Type has parameters, take the first part
+ mimeType = strings.Split(mimeType, ";")[0]
+ }
+ if mimeType == "application/octet-stream" {
+ if common.DebugEnabled {
+ common.SysLog("MIME type is application/octet-stream, trying to guess from URL or filename")
+ }
+ // try to guess the MIME type from the url last segment
+ // (strings.Split always returns at least one element, so branch on
+ // whether the segment actually contains an extension instead)
+ urlParts := strings.Split(url, "/")
+ lastSegment := urlParts[len(urlParts)-1]
+ if strings.Contains(lastSegment, ".") {
+ // Extract the file extension
+ filename := strings.Split(lastSegment, ".")
+ if len(filename) > 1 {
+ ext := strings.ToLower(filename[len(filename)-1])
+ // Guess MIME type based on file extension
+ mimeType = GetMimeTypeByExtension(ext)
+ }
+ } else {
+ // try to guess the MIME type from the file extension
+ fileName := resp.Header.Get("Content-Disposition")
+ if fileName != "" {
+ // Extract the filename from the Content-Disposition header
+ parts := strings.Split(fileName, ";")
+ for _, part := range parts {
+ if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
+ fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
+ // Remove quotes if present
+ if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
+ fileName = fileName[1 : len(fileName)-1]
+ }
+ // Guess MIME type from the extension after the last dot
+ if dot := strings.LastIndex(fileName, "."); dot != -1 && dot < len(fileName)-1 {
+ mimeType = GetMimeTypeByExtension(strings.ToLower(fileName[dot+1:]))
+ }
+ break
+ }
+ }
+ }
+ }
+ }
+
return &dto.LocalFileData{
Base64Data: base64Data,
- MimeType: resp.Header.Get("Content-Type"),
+ MimeType: mimeType,
Size: int64(len(fileBytes)),
}, nil
}
+
+func GetMimeTypeByExtension(ext string) string {
+ // Convert to lowercase for case-insensitive comparison
+ ext = strings.ToLower(ext)
+ switch ext {
+ // Text files
+ case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
+ return "text/plain"
+
+ // Image files
+ case "jpg", "jpeg":
+ return "image/jpeg"
+ case "png":
+ return "image/png"
+ case "gif":
+ return "image/gif"
+
+ // Audio files
+ case "mp3":
+ return "audio/mp3"
+ case "wav":
+ return "audio/wav"
+ case "mpeg":
+ return "audio/mpeg"
+
+ // Video files
+ case "mp4":
+ return "video/mp4"
+ case "wmv":
+ return "video/wmv"
+ case "flv":
+ return "video/flv"
+ case "mov":
+ return "video/mov"
+ case "mpg":
+ return "video/mpg"
+ case "avi":
+ return "video/avi"
+ case "mpegps":
+ return "video/mpegps"
+
+ // Document files
+ case "pdf":
+ return "application/pdf"
+
+ default:
+ return "application/octet-stream" // Default for unknown types
+ }
+}
diff --git a/service/log_info_generate.go b/service/log_info_generate.go
index 75457b97..1edc9073 100644
--- a/service/log_info_generate.go
+++ b/service/log_info_generate.go
@@ -8,7 +8,7 @@ import (
)
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
- cacheTokens int, cacheRatio float64, modelPrice float64) map[string]interface{} {
+ cacheTokens int, cacheRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
@@ -16,6 +16,7 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
other["cache_tokens"] = cacheTokens
other["cache_ratio"] = cacheRatio
other["model_price"] = modelPrice
+ other["user_group_ratio"] = userGroupRatio
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
if relayInfo.ReasoningEffort != "" {
other["reasoning_effort"] = relayInfo.ReasoningEffort
@@ -30,8 +31,8 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
return other
}
-func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
+func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
info["ws"] = true
info["audio_input"] = usage.InputTokenDetails.AudioTokens
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
@@ -42,8 +43,8 @@ func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us
return info
}
-func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
+func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
info["audio"] = true
info["audio_input"] = usage.PromptTokensDetails.AudioTokens
info["audio_output"] = usage.CompletionTokenDetails.AudioTokens
@@ -55,8 +56,8 @@ func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
- cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
+ cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
info["claude"] = true
info["cache_creation_tokens"] = cacheCreationTokens
info["cache_creation_ratio"] = cacheCreationRatio
diff --git a/service/quota.go b/service/quota.go
index 0d11b4a0..973deba7 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -3,6 +3,7 @@ package service
import (
"errors"
"fmt"
+ "log"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
@@ -10,7 +11,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
"strings"
"time"
@@ -45,9 +46,9 @@ func calculateAudioQuota(info QuotaInfo) int {
return int(quota.IntPart())
}
- completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
- audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
- audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
+ completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
+ audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
+ audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))
groupRatio := decimal.NewFromFloat(info.GroupRatio)
modelRatio := decimal.NewFromFloat(info.ModelRatio)
@@ -93,8 +94,21 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
- modelRatio, _ := operation_setting.GetModelRatio(modelName)
+ groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
+ modelRatio, _ := ratio_setting.GetModelRatio(modelName)
+
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
+ log.Printf("final group ratio: %f", groupRatio)
+ relayInfo.Group = autoGroup.(string)
+ }
+
+ actualGroupRatio := groupRatio
+ userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
+ if ok {
+ actualGroupRatio = userGroupRatio
+ }
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -108,7 +122,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
ModelName: modelName,
UsePrice: relayInfo.UsePrice,
ModelRatio: modelRatio,
- GroupRatio: groupRatio,
+ GroupRatio: actualGroupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -130,8 +144,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
- usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
- modelPrice float64, usePrice bool, extraContent string) {
+ usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
@@ -141,9 +154,14 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioOutTokens := usage.OutputTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
- completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
- audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
- audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
+ completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
+ audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
+ audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
+
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
+ modelPrice := priceData.ModelPrice
+ usePrice := priceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -189,7 +207,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -205,9 +223,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
tokenName := ctx.GetString("token_name")
completionRatio := priceData.CompletionRatio
modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
-
cacheRatio := priceData.CacheRatio
cacheTokens := usage.PromptTokensDetails.CachedTokens
@@ -256,7 +273,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
- cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice)
+ cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -272,12 +289,12 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
- completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
- audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
- audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
+ completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
+ audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
+ audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
@@ -333,7 +350,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
diff --git a/service/token_counter.go b/service/token_counter.go
index d63b54ad..53c6c2fa 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -4,6 +4,8 @@ import (
"encoding/json"
"errors"
"fmt"
+ "github.com/tiktoken-go/tokenizer"
+ "github.com/tiktoken-go/tokenizer/codec"
"image"
"log"
"math"
@@ -11,78 +13,63 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
- "one-api/setting/operation_setting"
"strings"
+ "sync"
"unicode/utf8"
-
- "github.com/pkoukk/tiktoken-go"
)
// tokenEncoderMap won't grow after initialization
-var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
-var defaultTokenEncoder *tiktoken.Tiktoken
-var o200kTokenEncoder *tiktoken.Tiktoken
+var defaultTokenEncoder tokenizer.Codec
+
+// tokenEncoderMap is used to store token encoders for different models
+var tokenEncoderMap = make(map[string]tokenizer.Codec)
+
+// tokenEncoderMutex protects tokenEncoderMap for concurrent access
+var tokenEncoderMutex sync.RWMutex
func InitTokenEncoders() {
common.SysLog("initializing token encoders")
- cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
- if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
- }
- defaultTokenEncoder = cl100TokenEncoder
- o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
- if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
- }
- for model, _ := range operation_setting.GetDefaultModelRatioMap() {
- if strings.HasPrefix(model, "gpt-3.5") {
- tokenEncoderMap[model] = cl100TokenEncoder
- } else if strings.HasPrefix(model, "gpt-4") {
- if strings.HasPrefix(model, "gpt-4o") {
- tokenEncoderMap[model] = o200kTokenEncoder
- } else {
- tokenEncoderMap[model] = defaultTokenEncoder
- }
- } else if strings.HasPrefix(model, "o") {
- tokenEncoderMap[model] = o200kTokenEncoder
- } else {
- tokenEncoderMap[model] = defaultTokenEncoder
- }
- }
+ defaultTokenEncoder = codec.NewCl100kBase()
common.SysLog("token encoders initialized")
}
-func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
- if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
- return o200kTokenEncoder
+func getTokenEncoder(model string) tokenizer.Codec {
+ // First, try to get the encoder from cache with read lock
+ tokenEncoderMutex.RLock()
+ if encoder, exists := tokenEncoderMap[model]; exists {
+ tokenEncoderMutex.RUnlock()
+ return encoder
}
- return defaultTokenEncoder
+ tokenEncoderMutex.RUnlock()
+
+ // If not in cache, create new encoder with write lock
+ tokenEncoderMutex.Lock()
+ defer tokenEncoderMutex.Unlock()
+
+ // Double-check if another goroutine already created the encoder
+ if encoder, exists := tokenEncoderMap[model]; exists {
+ return encoder
+ }
+
+ // Create new encoder
+ modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
+ if err != nil {
+ // Cache the default encoder for this model to avoid repeated failures
+ tokenEncoderMap[model] = defaultTokenEncoder
+ return defaultTokenEncoder
+ }
+
+ // Cache the new encoder
+ tokenEncoderMap[model] = modelCodec
+ return modelCodec
}
-func getTokenEncoder(model string) *tiktoken.Tiktoken {
- tokenEncoder, ok := tokenEncoderMap[model]
- if ok && tokenEncoder != nil {
- return tokenEncoder
- }
- // 如果ok(即model在tokenEncoderMap中),但是tokenEncoder为nil,说明可能是自定义模型
- if ok {
- tokenEncoder, err := tiktoken.EncodingForModel(model)
- if err != nil {
- common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
- tokenEncoder = getModelDefaultTokenEncoder(model)
- }
- tokenEncoderMap[model] = tokenEncoder
- return tokenEncoder
- }
- // 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder
- return getModelDefaultTokenEncoder(model)
-}
-
-func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
+func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
if text == "" {
return 0
}
- return len(tokenEncoder.Encode(text, nil, nil))
+ tkm, _ := tokenEncoder.Count(text)
+ return tkm
}
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
@@ -184,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
- toolTokens, err := CountTokenInput(countStr, request.Model)
+ toolTokens := CountTokenInput(countStr, request.Model)
if err != nil {
return 0, err
}
@@ -207,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
// Count tokens in system message
if request.System != "" {
- systemTokens, err := CountTokenInput(request.System, model)
+ systemTokens := CountTokenInput(request.System, model)
if err != nil {
return 0, err
}
@@ -261,12 +248,16 @@ func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream
//}
tokenNum += 1000
case "tool_use":
- tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
- inputJSON, _ := json.Marshal(mediaMessage.Input)
- tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
+ if mediaMessage.Input != nil {
+ tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
+ inputJSON, _ := json.Marshal(mediaMessage.Input)
+ tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
+ }
case "tool_result":
- contentJSON, _ := json.Marshal(mediaMessage.Content)
- tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
+ if mediaMessage.Content != nil {
+ contentJSON, _ := json.Marshal(mediaMessage.Content)
+ tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
+ }
}
}
}
@@ -305,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
switch request.Type {
case dto.RealtimeEventTypeSessionUpdate:
if request.Session != nil {
- msgTokens, err := CountTextToken(request.Session.Instructions, model)
- if err != nil {
- return 0, 0, err
- }
+ msgTokens := CountTextToken(request.Session.Instructions, model)
textToken += msgTokens
}
case dto.RealtimeEventResponseAudioDelta:
@@ -320,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
audioToken += atk
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
// count text token
- tkm, err := CountTextToken(request.Delta, model)
- if err != nil {
- return 0, 0, fmt.Errorf("error counting text token: %v", err)
- }
+ tkm := CountTextToken(request.Delta, model)
textToken += tkm
case dto.RealtimeEventInputAudioBufferAppend:
// count audio token
@@ -338,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
case "message":
for _, content := range request.Item.Content {
if content.Type == "input_text" {
- tokens, err := CountTextToken(content.Text, model)
- if err != nil {
- return 0, 0, err
- }
+ tokens := CountTextToken(content.Text, model)
textToken += tokens
}
}
@@ -352,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools {
- toolTokens, err := CountTokenInput(tool, model)
- if err != nil {
- return 0, 0, err
- }
+ toolTokens := CountTokenInput(tool, model)
textToken += 8
textToken += toolTokens
}
@@ -386,7 +365,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
- if len(message.Content) > 0 {
+ if message.Content != nil {
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
@@ -418,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
return tokenNum, nil
}
-func CountTokenInput(input any, model string) (int, error) {
+func CountTokenInput(input any, model string) int {
switch v := input.(type) {
case string:
return CountTextToken(v, model)
@@ -441,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0
for _, message := range messages {
- tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
+ tkm := CountTokenInput(message.Delta.GetContentString(), model)
tokens += tkm
if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls {
- tkm, _ := CountTokenInput(tool.Function.Name, model)
+ tkm := CountTokenInput(tool.Function.Name, model)
tokens += tkm
- tkm, _ = CountTokenInput(tool.Function.Arguments, model)
+ tkm = CountTokenInput(tool.Function.Arguments, model)
tokens += tkm
}
}
@@ -455,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens
}
-func CountTTSToken(text string, model string) (int, error) {
+func CountTTSToken(text string, model string) int {
if strings.HasPrefix(model, "tts") {
- return utf8.RuneCountInString(text), nil
+ return utf8.RuneCountInString(text)
} else {
return CountTextToken(text, model)
}
@@ -492,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
//}
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTextToken(text string, model string) (int, error) {
- var err error
+func CountTextToken(text string, model string) int {
+ if text == "" {
+ return 0
+ }
tokenEncoder := getTokenEncoder(model)
- return getTokenNum(tokenEncoder, text), err
+ return getTokenNum(tokenEncoder, text)
}
diff --git a/service/usage_helpr.go b/service/usage_helpr.go
index c52e1e15..ca9c0830 100644
--- a/service/usage_helpr.go
+++ b/service/usage_helpr.go
@@ -16,13 +16,13 @@ import (
// return 0, errors.New("unknown relay mode")
//}
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
+func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
- ctkm, err := CountTextToken(responseText, modeName)
+ ctkm := CountTextToken(responseText, modeName)
usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- return usage, err
+ return usage
}
func ValidUsage(usage *dto.Usage) bool {
diff --git a/setting/auto_group.go b/setting/auto_group.go
new file mode 100644
index 00000000..5a87ae56
--- /dev/null
+++ b/setting/auto_group.go
@@ -0,0 +1,31 @@
+package setting
+
+import "encoding/json"
+
+var AutoGroups = []string{
+ "default",
+}
+
+var DefaultUseAutoGroup = false
+
+func ContainsAutoGroup(group string) bool {
+ for _, autoGroup := range AutoGroups {
+ if autoGroup == group {
+ return true
+ }
+ }
+ return false
+}
+
+func UpdateAutoGroupsByJsonString(jsonString string) error {
+ AutoGroups = make([]string, 0)
+ return json.Unmarshal([]byte(jsonString), &AutoGroups)
+}
+
+func AutoGroups2JsonString() string {
+ jsonBytes, err := json.Marshal(AutoGroups)
+ if err != nil {
+ return "[]"
+ }
+ return string(jsonBytes)
+}
diff --git a/setting/console_setting/config.go b/setting/console_setting/config.go
new file mode 100644
index 00000000..6327e558
--- /dev/null
+++ b/setting/console_setting/config.go
@@ -0,0 +1,39 @@
+package console_setting
+
+import "one-api/setting/config"
+
+type ConsoleSetting struct {
+ ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串)
+ UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串)
+ Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串)
+ FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串)
+ ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板
+ UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板
+ AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板
+ FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板
+}
+
+// 默认配置
+var defaultConsoleSetting = ConsoleSetting{
+ ApiInfo: "",
+ UptimeKumaGroups: "",
+ Announcements: "",
+ FAQ: "",
+ ApiInfoEnabled: true,
+ UptimeKumaEnabled: true,
+ AnnouncementsEnabled: true,
+ FAQEnabled: true,
+}
+
+// 全局实例
+var consoleSetting = defaultConsoleSetting
+
+func init() {
+ // 注册到全局配置管理器,键名为 console_setting
+ config.GlobalConfig.Register("console_setting", &consoleSetting)
+}
+
+// GetConsoleSetting 获取 ConsoleSetting 配置实例
+func GetConsoleSetting() *ConsoleSetting {
+ return &consoleSetting
+}
diff --git a/setting/console_setting/validation.go b/setting/console_setting/validation.go
new file mode 100644
index 00000000..fda6453d
--- /dev/null
+++ b/setting/console_setting/validation.go
@@ -0,0 +1,304 @@
+package console_setting
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/url"
+ "regexp"
+ "sort"
+ "strings"
+ "time"
+)
+
+var (
+ urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`)
+ dangerousChars = []string{"