diff --git a/README.md b/README.md
index e9d1c154..6ba3574c 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,9 @@
+
+
+
@@ -180,7 +183,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
其他基于New API的项目:
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
-- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
## 帮助支持
diff --git a/common/database.go b/common/database.go
index 3c0a944b..9cbaf46a 100644
--- a/common/database.go
+++ b/common/database.go
@@ -1,7 +1,14 @@
package common
+const (
+ DatabaseTypeMySQL = "mysql"
+ DatabaseTypeSQLite = "sqlite"
+ DatabaseTypePostgreSQL = "postgres"
+)
+
var UsingSQLite = false
var UsingPostgreSQL = false
+var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
var UsingMySQL = false
var UsingClickHouse = false
diff --git a/common/redis.go b/common/redis.go
index ba35331a..1efc217f 100644
--- a/common/redis.go
+++ b/common/redis.go
@@ -141,7 +141,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
txn := RDB.TxPipeline()
txn.HSet(ctx, key, data)
- txn.Expire(ctx, key, expiration)
+
+ // 只有在 expiration 大于 0 时才设置过期时间
+ if expiration > 0 {
+ txn.Expire(ctx, key, expiration)
+ }
_, err := txn.Exec(ctx)
if err != nil {
diff --git a/common/utils.go b/common/utils.go
index 587de537..d9db67d0 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -249,13 +249,38 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
}
// GetAudioDuration returns the duration of an audio file in seconds.
-func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
+func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration")
}
+ durationStr := string(bytes.TrimSpace(output))
+ if durationStr == "N/A" {
+ // Create a temporary output file name
+ tmpFp, err := os.CreateTemp("", "audio-*"+ext)
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to create temporary file")
+ }
+ tmpName := tmpFp.Name()
+ // Close immediately so ffmpeg can open the file on Windows.
+ _ = tmpFp.Close()
+ defer os.Remove(tmpName)
- return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
+ // ffmpeg -y -i filename -vcodec copy -acodec copy
+ ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
+ if err := ffmpegCmd.Run(); err != nil {
+ return 0, errors.Wrap(err, "failed to run ffmpeg")
+ }
+
+ // Recalculate the duration of the new file
+ c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
+ output, err := c.Output()
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
+ }
+ durationStr = string(bytes.TrimSpace(output))
+ }
+ return strconv.ParseFloat(durationStr, 64)
}
diff --git a/constant/cache_key.go b/constant/cache_key.go
index 27cb3b75..daedfd40 100644
--- a/constant/cache_key.go
+++ b/constant/cache_key.go
@@ -2,12 +2,10 @@ package constant
import "one-api/common"
-var (
- TokenCacheSeconds = common.SyncFrequency
- UserId2GroupCacheSeconds = common.SyncFrequency
- UserId2QuotaCacheSeconds = common.SyncFrequency
- UserId2StatusCacheSeconds = common.SyncFrequency
-)
+// 使用函数来避免初始化顺序带来的赋值问题
+func RedisKeyCacheSeconds() int {
+ return common.SyncFrequency
+}
// Cache keys
const (
diff --git a/constant/user_setting.go b/constant/user_setting.go
index 055884f7..7e79035e 100644
--- a/constant/user_setting.go
+++ b/constant/user_setting.go
@@ -7,6 +7,7 @@ var (
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
+ UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP
)
var (
diff --git a/controller/channel-test.go b/controller/channel-test.go
index f9c7bf7b..d162d8cf 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -165,8 +165,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
- other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
- usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
+ other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
+ usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
@@ -312,7 +312,7 @@ func testAllChannels(notify bool) error {
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
-
+
if notify {
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
diff --git a/controller/channel.go b/controller/channel.go
index a4ef87c3..1cfb7906 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -43,22 +43,23 @@ type OpenAIModelsResponse struct {
func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
}
- if pageSize < 0 {
+ if pageSize < 1 {
pageSize = common.ItemsPerPage
}
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
+
+ var total int64
+
if enableTagMode {
- tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
+ // tag 分页:先分页 tag,再取各 tag 下 channels
+ tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
for _, tag := range tags {
@@ -69,21 +70,27 @@ func GetAllChannels(c *gin.Context) {
}
}
}
+ // 计算 tag 总数用于分页
+ total, _ = model.CountAllTags()
} else {
- channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
+ channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
channelData = channels
+ total, _ = model.CountAllChannels()
}
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": channelData,
+ "data": gin.H{
+ "items": channelData,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ },
})
return
}
diff --git a/controller/console_migrate.go b/controller/console_migrate.go
new file mode 100644
index 00000000..d25f199b
--- /dev/null
+++ b/controller/console_migrate.go
@@ -0,0 +1,103 @@
+// 用于迁移检测的旧键,该文件下个版本会删除
+
+package controller
+
+import (
+ "encoding/json"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "github.com/gin-gonic/gin"
+)
+
+// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
+func MigrateConsoleSetting(c *gin.Context) {
+ // 读取全部 option
+ opts, err := model.AllOption()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ // 建立 map
+ valMap := map[string]string{}
+ for _, o := range opts {
+ valMap[o.Key] = o.Value
+ }
+
+ // 处理 APIInfo
+ if v := valMap["ApiInfo"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ if len(arr) > 50 {
+ arr = arr[:50]
+ }
+ bytes, _ := json.Marshal(arr)
+ model.UpdateOption("console_setting.api_info", string(bytes))
+ }
+ model.UpdateOption("ApiInfo", "")
+ }
+ // Announcements 直接搬
+ if v := valMap["Announcements"]; v != "" {
+ model.UpdateOption("console_setting.announcements", v)
+ model.UpdateOption("Announcements", "")
+ }
+ // FAQ 转换
+ if v := valMap["FAQ"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ out := []map[string]interface{}{}
+ for _, item := range arr {
+ q, _ := item["question"].(string)
+ if q == "" {
+ q, _ = item["title"].(string)
+ }
+ a, _ := item["answer"].(string)
+ if a == "" {
+ a, _ = item["content"].(string)
+ }
+ if q != "" && a != "" {
+ out = append(out, map[string]interface{}{"question": q, "answer": a})
+ }
+ }
+ if len(out) > 50 {
+ out = out[:50]
+ }
+ bytes, _ := json.Marshal(out)
+ model.UpdateOption("console_setting.faq", string(bytes))
+ }
+ model.UpdateOption("FAQ", "")
+ }
+ // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
+ url := valMap["UptimeKumaUrl"]
+ slug := valMap["UptimeKumaSlug"]
+ if url != "" && slug != "" {
+ // 仅当同时存在 URL 与 Slug 时才进行迁移
+ groups := []map[string]interface{}{
+ {
+ "id": 1,
+ "categoryName": "old",
+ "url": url,
+ "slug": slug,
+ "description": "",
+ },
+ }
+ bytes, _ := json.Marshal(groups)
+ model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
+ }
+ // 清空旧键内容
+ if url != "" {
+ model.UpdateOption("UptimeKumaUrl", "")
+ }
+ if slug != "" {
+ model.UpdateOption("UptimeKumaSlug", "")
+ }
+
+ // 删除旧键记录
+ oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
+ model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
+
+ // 重新加载 OptionMap
+ model.InitOptionMap()
+ common.SysLog("console setting migrated")
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
+}
\ No newline at end of file
diff --git a/controller/group.go b/controller/group.go
index 2c725a4d..632b6cd5 100644
--- a/controller/group.go
+++ b/controller/group.go
@@ -1,10 +1,11 @@
package controller
import (
- "github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"one-api/setting"
+
+ "github.com/gin-gonic/gin"
)
func GetGroups(c *gin.Context) {
@@ -34,6 +35,12 @@ func GetUserGroups(c *gin.Context) {
}
}
}
+ if setting.GroupInUserUsableGroups("auto") {
+ usableGroups["auto"] = map[string]interface{}{
+ "ratio": "自动",
+ "desc": setting.GetUsableGroupDescription("auto"),
+ }
+ }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 21027d8f..56bdcb80 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -7,7 +7,6 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"io"
- "log"
"net/http"
"one-api/common"
"one-api/dto"
@@ -215,8 +214,12 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
+ }
+ pageSize, _ := strconv.Atoi(c.Query("page_size"))
+ if pageSize <= 0 {
+ pageSize = common.ItemsPerPage
}
// 解析其他查询参数
@@ -227,31 +230,38 @@ func GetAllMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"),
}
- logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Midjourney, 0)
- }
+ items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams)
+ total := model.CountAllTasks(queryParams)
+
if setting.MjForwardUrlEnabled {
- for i, midjourney := range logs {
+ for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
- logs[i] = midjourney
+ items[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
- "data": logs,
+ "data": gin.H{
+ "items": items,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ },
})
}
func GetUserMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
+ }
+ pageSize, _ := strconv.Atoi(c.Query("page_size"))
+ if pageSize <= 0 {
+ pageSize = common.ItemsPerPage
}
userId := c.GetInt("id")
- log.Printf("userId = %d \n", userId)
queryParams := model.TaskQueryParams{
MjID: c.Query("mj_id"),
@@ -259,19 +269,23 @@ func GetUserMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"),
}
- logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Midjourney, 0)
- }
+ items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
+ total := model.CountAllUserTask(userId, queryParams)
+
if setting.MjForwardUrlEnabled {
- for i, midjourney := range logs {
+ for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
- logs[i] = midjourney
+ items[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
- "data": logs,
+ "data": gin.H{
+ "items": items,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ },
})
}
diff --git a/controller/misc.go b/controller/misc.go
index 8fa8e8f6..1caaf640 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -9,6 +9,7 @@ import (
"one-api/middleware"
"one-api/model"
"one-api/setting"
+ "one-api/setting/console_setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"strings"
@@ -37,52 +38,72 @@ func TestStatus(c *gin.Context) {
func GetStatus(c *gin.Context) {
+ cs := console_setting.GetConsoleSetting()
+
+ data := gin.H{
+ "version": common.Version,
+ "start_time": common.StartTime,
+ "email_verification": common.EmailVerificationEnabled,
+ "github_oauth": common.GitHubOAuthEnabled,
+ "github_client_id": common.GitHubClientId,
+ "linuxdo_oauth": common.LinuxDOOAuthEnabled,
+ "linuxdo_client_id": common.LinuxDOClientId,
+ "telegram_oauth": common.TelegramOAuthEnabled,
+ "telegram_bot_name": common.TelegramBotName,
+ "system_name": common.SystemName,
+ "logo": common.Logo,
+ "footer_html": common.Footer,
+ "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
+ "wechat_login": common.WeChatAuthEnabled,
+ "server_address": setting.ServerAddress,
+ "price": setting.Price,
+ "min_topup": setting.MinTopUp,
+ "turnstile_check": common.TurnstileCheckEnabled,
+ "turnstile_site_key": common.TurnstileSiteKey,
+ "top_up_link": common.TopUpLink,
+ "docs_link": operation_setting.GetGeneralSetting().DocsLink,
+ "quota_per_unit": common.QuotaPerUnit,
+ "display_in_currency": common.DisplayInCurrencyEnabled,
+ "enable_batch_update": common.BatchUpdateEnabled,
+ "enable_drawing": common.DrawingEnabled,
+ "enable_task": common.TaskEnabled,
+ "enable_data_export": common.DataExportEnabled,
+ "data_export_default_time": common.DataExportDefaultTime,
+ "default_collapse_sidebar": common.DefaultCollapseSidebar,
+ "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
+ "mj_notify_enabled": setting.MjNotifyEnabled,
+ "chats": setting.Chats,
+ "demo_site_enabled": operation_setting.DemoSiteEnabled,
+ "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
+ "default_use_auto_group": setting.DefaultUseAutoGroup,
+
+ // 面板启用开关
+ "api_info_enabled": cs.ApiInfoEnabled,
+ "uptime_kuma_enabled": cs.UptimeKumaEnabled,
+ "announcements_enabled": cs.AnnouncementsEnabled,
+ "faq_enabled": cs.FAQEnabled,
+
+ "oidc_enabled": system_setting.GetOIDCSettings().Enabled,
+ "oidc_client_id": system_setting.GetOIDCSettings().ClientId,
+ "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
+ "setup": constant.Setup,
+ }
+
+ // 根据启用状态注入可选内容
+ if cs.ApiInfoEnabled {
+ data["api_info"] = console_setting.GetApiInfo()
+ }
+ if cs.AnnouncementsEnabled {
+ data["announcements"] = console_setting.GetAnnouncements()
+ }
+ if cs.FAQEnabled {
+ data["faq"] = console_setting.GetFAQ()
+ }
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": gin.H{
- "version": common.Version,
- "start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "linuxdo_oauth": common.LinuxDOOAuthEnabled,
- "linuxdo_client_id": common.LinuxDOClientId,
- "telegram_oauth": common.TelegramOAuthEnabled,
- "telegram_bot_name": common.TelegramBotName,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": setting.ServerAddress,
- "price": setting.Price,
- "min_topup": setting.MinTopUp,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "docs_link": operation_setting.GetGeneralSetting().DocsLink,
- "quota_per_unit": common.QuotaPerUnit,
- "display_in_currency": common.DisplayInCurrencyEnabled,
- "enable_batch_update": common.BatchUpdateEnabled,
- "enable_drawing": common.DrawingEnabled,
- "enable_task": common.TaskEnabled,
- "enable_data_export": common.DataExportEnabled,
- "data_export_default_time": common.DataExportDefaultTime,
- "default_collapse_sidebar": common.DefaultCollapseSidebar,
- "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
- "mj_notify_enabled": setting.MjNotifyEnabled,
- "chats": setting.Chats,
- "demo_site_enabled": operation_setting.DemoSiteEnabled,
- "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
- "oidc_enabled": system_setting.GetOIDCSettings().Enabled,
- "oidc_client_id": system_setting.GetOIDCSettings().ClientId,
- "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
- "setup": constant.Setup,
- "api_info": setting.GetApiInfo(),
- "announcements": setting.GetAnnouncements(),
- "faq": setting.GetFAQ(),
- },
+ "data": data,
})
return
}
diff --git a/controller/model.go b/controller/model.go
index df7e59a6..134217a3 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -2,7 +2,6 @@ package controller
import (
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -15,6 +14,9 @@ import (
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/setting"
+
+ "github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -179,7 +181,19 @@ func ListModels(c *gin.Context) {
if tokenGroup != "" {
group = tokenGroup
}
- models := model.GetGroupModels(group)
+ var models []string
+ if tokenGroup == "auto" {
+ for _, autoGroup := range setting.AutoGroups {
+ groupModels := model.GetGroupModels(autoGroup)
+ for _, g := range groupModels {
+ if !common.StringsContains(models, g) {
+ models = append(models, g)
+ }
+ }
+ }
+ } else {
+ models = model.GetGroupModels(group)
+ }
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
diff --git a/controller/option.go b/controller/option.go
index b52012fd..79ba2ffe 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -6,6 +6,7 @@ import (
"one-api/common"
"one-api/model"
"one-api/setting"
+ "one-api/setting/console_setting"
"one-api/setting/system_setting"
"strings"
@@ -119,8 +120,8 @@ func UpdateOption(c *gin.Context) {
})
return
}
- case "ApiInfo":
- err = setting.ValidateApiInfo(option.Value)
+ case "console_setting.api_info":
+ err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -128,8 +129,8 @@ func UpdateOption(c *gin.Context) {
})
return
}
- case "Announcements":
- err = setting.ValidateConsoleSettings(option.Value, "Announcements")
+ case "console_setting.announcements":
+ err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -137,8 +138,17 @@ func UpdateOption(c *gin.Context) {
})
return
}
- case "FAQ":
- err = setting.ValidateConsoleSettings(option.Value, "FAQ")
+ case "console_setting.faq":
+ err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.uptime_kuma_groups":
+ err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
diff --git a/controller/playground.go b/controller/playground.go
index a2b54790..37a5c7b0 100644
--- a/controller/playground.go
+++ b/controller/playground.go
@@ -3,7 +3,6 @@ package controller
import (
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -13,6 +12,8 @@ import (
"one-api/service"
"one-api/setting"
"time"
+
+ "github.com/gin-gonic/gin"
)
func Playground(c *gin.Context) {
@@ -57,9 +58,9 @@ func Playground(c *gin.Context) {
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
- channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
+ channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
diff --git a/controller/pricing.go b/controller/pricing.go
index 1cbfe731..e6a3e57f 100644
--- a/controller/pricing.go
+++ b/controller/pricing.go
@@ -1,10 +1,11 @@
package controller
import (
- "github.com/gin-gonic/gin"
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
+
+ "github.com/gin-gonic/gin"
)
func GetPricing(c *gin.Context) {
@@ -20,6 +21,12 @@ func GetPricing(c *gin.Context) {
user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
+ for g := range groupRatio {
+ ratio, ok := setting.GetGroupGroupRatio(group, g)
+ if ok {
+ groupRatio[g] = ratio
+ }
+ }
}
}
diff --git a/controller/redemption.go b/controller/redemption.go
index a7e09a8a..50620597 100644
--- a/controller/redemption.go
+++ b/controller/redemption.go
@@ -5,6 +5,7 @@ import (
"one-api/common"
"one-api/model"
"strconv"
+ "errors"
"github.com/gin-gonic/gin"
)
@@ -126,6 +127,10 @@ func AddRedemption(c *gin.Context) {
})
return
}
+ if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
var keys []string
for i := 0; i < redemption.Count; i++ {
key := common.GetUUID()
@@ -135,6 +140,7 @@ func AddRedemption(c *gin.Context) {
Key: key,
CreatedTime: common.GetTimestamp(),
Quota: redemption.Quota,
+ ExpiredTime: redemption.ExpiredTime,
}
err = cleanRedemption.Insert()
if err != nil {
@@ -191,12 +197,18 @@ func UpdateRedemption(c *gin.Context) {
})
return
}
- if statusOnly != "" {
- cleanRedemption.Status = redemption.Status
- } else {
+ if statusOnly == "" {
+ if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
// If you add more fields, please also update redemption.Update()
cleanRedemption.Name = redemption.Name
cleanRedemption.Quota = redemption.Quota
+ cleanRedemption.ExpiredTime = redemption.ExpiredTime
+ }
+ if statusOnly != "" {
+ cleanRedemption.Status = redemption.Status
}
err = cleanRedemption.Update()
if err != nil {
@@ -213,3 +225,27 @@ func UpdateRedemption(c *gin.Context) {
})
return
}
+
+func DeleteInvalidRedemption(c *gin.Context) {
+ rows, err := model.DeleteInvalidRedemptions()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": rows,
+ })
+ return
+}
+
+func validateExpiredTime(expired int64) error {
+ if expired != 0 && expired < common.GetTimestamp() {
+ return errors.New("过期时间不能早于当前时间")
+ }
+ return nil
+}
diff --git a/controller/relay.go b/controller/relay.go
index 1a875dbc..c1c45114 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
+ channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+ channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
diff --git a/controller/setup.go b/controller/setup.go
index 0a13bcf9..8943a1a0 100644
--- a/controller/setup.go
+++ b/controller/setup.go
@@ -75,6 +75,14 @@ func PostSetup(c *gin.Context) {
// If root doesn't exist, validate and create admin account
if !rootExists {
+ // Validate username length: max 12 characters to align with model.User validation
+ if len(req.Username) > 12 {
+ c.JSON(400, gin.H{
+ "success": false,
+ "message": "用户名长度不能超过12个字符",
+ })
+ return
+ }
// Validate password
if req.Password != req.ConfirmPassword {
c.JSON(400, gin.H{
diff --git a/controller/task.go b/controller/task.go
index 65f79ead..34e14f3f 100644
--- a/controller/task.go
+++ b/controller/task.go
@@ -224,9 +224,14 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
func GetAllTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
}
+ pageSize, _ := strconv.Atoi(c.Query("page_size"))
+ if pageSize <= 0 {
+ pageSize = common.ItemsPerPage
+ }
+
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
// 解析其他查询参数
@@ -237,24 +242,32 @@ func GetAllTask(c *gin.Context) {
Action: c.Query("action"),
StartTimestamp: startTimestamp,
EndTimestamp: endTimestamp,
+ ChannelID: c.Query("channel_id"),
}
- logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Task, 0)
- }
+ items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams)
+ total := model.TaskCountAllTasks(queryParams)
c.JSON(200, gin.H{
"success": true,
"message": "",
- "data": logs,
+ "data": gin.H{
+ "items": items,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ },
})
}
func GetUserTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
+ }
+ pageSize, _ := strconv.Atoi(c.Query("page_size"))
+ if pageSize <= 0 {
+ pageSize = common.ItemsPerPage
}
userId := c.GetInt("id")
@@ -271,14 +284,17 @@ func GetUserTask(c *gin.Context) {
EndTimestamp: endTimestamp,
}
- logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Task, 0)
- }
+ items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
+ total := model.TaskCountAllUserTask(userId, queryParams)
c.JSON(200, gin.H{
"success": true,
"message": "",
- "data": logs,
+ "data": gin.H{
+ "items": items,
+ "total": total,
+ "page": p,
+ "page_size": pageSize,
+ },
})
}
diff --git a/controller/token.go b/controller/token.go
index a8803279..c57552c0 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -12,15 +12,15 @@ func GetAllTokens(c *gin.Context) {
userId := c.GetInt("id")
p, _ := strconv.Atoi(c.Query("p"))
size, _ := strconv.Atoi(c.Query("size"))
- if p < 0 {
- p = 0
+ if p < 1 {
+ p = 1
}
if size <= 0 {
size = common.ItemsPerPage
} else if size > 100 {
size = 100
}
- tokens, err := model.GetAllUserTokens(userId, p*size, size)
+ tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -28,10 +28,18 @@ func GetAllTokens(c *gin.Context) {
})
return
}
+ // Get total count for pagination
+ total, _ := model.CountUserTokens(userId)
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": tokens,
+ "data": gin.H{
+ "items": tokens,
+ "total": total,
+ "page": p,
+ "page_size": size,
+ },
})
return
}
diff --git a/controller/uptime_kuma.go b/controller/uptime_kuma.go
index 6ceaa1f3..05d6297e 100644
--- a/controller/uptime_kuma.go
+++ b/controller/uptime_kuma.go
@@ -4,9 +4,9 @@ import (
"context"
"encoding/json"
"errors"
- "fmt"
"net/http"
- "one-api/common"
+ "one-api/setting/console_setting"
+ "strconv"
"strings"
"time"
@@ -14,45 +14,25 @@ import (
"golang.org/x/sync/errgroup"
)
-type UptimeKumaMonitor struct {
- ID int `json:"id"`
- Name string `json:"name"`
- Type string `json:"type"`
-}
+const (
+ requestTimeout = 30 * time.Second
+ httpTimeout = 10 * time.Second
+ uptimeKeySuffix = "_24"
+ apiStatusPath = "/api/status-page/"
+ apiHeartbeatPath = "/api/status-page/heartbeat/"
+)
-type UptimeKumaGroup struct {
- ID int `json:"id"`
- Name string `json:"name"`
- Weight int `json:"weight"`
- MonitorList []UptimeKumaMonitor `json:"monitorList"`
-}
-
-type UptimeKumaHeartbeat struct {
- Status int `json:"status"`
- Time string `json:"time"`
- Msg string `json:"msg"`
- Ping *float64 `json:"ping"`
-}
-
-type UptimeKumaStatusResponse struct {
- PublicGroupList []UptimeKumaGroup `json:"publicGroupList"`
-}
-
-type UptimeKumaHeartbeatResponse struct {
- HeartbeatList map[string][]UptimeKumaHeartbeat `json:"heartbeatList"`
- UptimeList map[string]float64 `json:"uptimeList"`
-}
-
-type MonitorStatus struct {
+type Monitor struct {
Name string `json:"name"`
Uptime float64 `json:"uptime"`
Status int `json:"status"`
+ Group string `json:"group,omitempty"`
}
-var (
- ErrUpstreamNon200 = errors.New("upstream non-200")
- ErrTimeout = errors.New("context deadline exceeded")
-)
+type UptimeGroupResult struct {
+ CategoryName string `json:"categoryName"`
+ Monitors []Monitor `json:"monitors"`
+}
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
@@ -62,108 +42,113 @@ func getAndDecode(ctx context.Context, client *http.Client, url string, dest int
resp, err := client.Do(req)
if err != nil {
- if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
- return ErrTimeout
- }
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
- return ErrUpstreamNon200
+ return errors.New("non-200 status")
}
return json.NewDecoder(resp.Body).Decode(dest)
}
-func GetUptimeKumaStatus(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- uptimeKumaUrl := common.OptionMap["UptimeKumaUrl"]
- slug := common.OptionMap["UptimeKumaSlug"]
- common.OptionMapRWMutex.RUnlock()
-
- if uptimeKumaUrl == "" || slug == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": []MonitorStatus{},
- })
- return
+func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
+ url, _ := groupConfig["url"].(string)
+ slug, _ := groupConfig["slug"].(string)
+ categoryName, _ := groupConfig["categoryName"].(string)
+
+ result := UptimeGroupResult{
+ CategoryName: categoryName,
+ Monitors: []Monitor{},
+ }
+
+ if url == "" || slug == "" {
+ return result
}
- uptimeKumaUrl = strings.TrimSuffix(uptimeKumaUrl, "/")
-
- ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
- defer cancel()
-
- client := &http.Client{}
-
- statusPageUrl := fmt.Sprintf("%s/api/status-page/%s", uptimeKumaUrl, slug)
- heartbeatUrl := fmt.Sprintf("%s/api/status-page/heartbeat/%s", uptimeKumaUrl, slug)
-
- var (
- statusData UptimeKumaStatusResponse
- heartbeatData UptimeKumaHeartbeatResponse
- )
+ baseURL := strings.TrimSuffix(url, "/")
+
+ var statusData struct {
+ PublicGroupList []struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ MonitorList []struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ } `json:"monitorList"`
+ } `json:"publicGroupList"`
+ }
+
+ var heartbeatData struct {
+ HeartbeatList map[string][]struct {
+ Status int `json:"status"`
+ } `json:"heartbeatList"`
+ UptimeList map[string]float64 `json:"uptimeList"`
+ }
g, gCtx := errgroup.WithContext(ctx)
-
- g.Go(func() error {
- return getAndDecode(gCtx, client, statusPageUrl, &statusData)
+ g.Go(func() error {
+ return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
+ })
+ g.Go(func() error {
+ return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
})
- g.Go(func() error {
- return getAndDecode(gCtx, client, heartbeatUrl, &heartbeatData)
- })
+ if g.Wait() != nil {
+ return result
+ }
- if err := g.Wait(); err != nil {
- switch err {
- case ErrUpstreamNon200:
- c.JSON(http.StatusBadRequest, gin.H{
- "success": false,
- "message": "上游接口出现问题",
- })
- case ErrTimeout:
- c.JSON(http.StatusRequestTimeout, gin.H{
- "success": false,
- "message": "请求上游接口超时",
- })
- default:
- c.JSON(http.StatusBadRequest, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ for _, pg := range statusData.PublicGroupList {
+ if len(pg.MonitorList) == 0 {
+ continue
}
+
+ for _, m := range pg.MonitorList {
+ monitor := Monitor{
+ Name: m.Name,
+ Group: pg.Name,
+ }
+
+ monitorID := strconv.Itoa(m.ID)
+
+ if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
+ monitor.Uptime = uptime
+ }
+
+ if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
+ monitor.Status = heartbeats[0].Status
+ }
+
+ result.Monitors = append(result.Monitors, monitor)
+ }
+ }
+
+ return result
+}
+
+func GetUptimeKumaStatus(c *gin.Context) {
+ groups := console_setting.GetUptimeKumaGroups()
+ if len(groups) == 0 {
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
return
}
- var monitors []MonitorStatus
- for _, group := range statusData.PublicGroupList {
- for _, monitor := range group.MonitorList {
- monitorStatus := MonitorStatus{
- Name: monitor.Name,
- Uptime: 0.0,
- Status: 0,
- }
+ ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
+ defer cancel()
- uptimeKey := fmt.Sprintf("%d_24", monitor.ID)
- if uptime, exists := heartbeatData.UptimeList[uptimeKey]; exists {
- monitorStatus.Uptime = uptime
- }
-
- heartbeatKey := fmt.Sprintf("%d", monitor.ID)
- if heartbeats, exists := heartbeatData.HeartbeatList[heartbeatKey]; exists && len(heartbeats) > 0 {
- latestHeartbeat := heartbeats[0]
- monitorStatus.Status = latestHeartbeat.Status
- }
-
- monitors = append(monitors, monitorStatus)
- }
+ client := &http.Client{Timeout: httpTimeout}
+ results := make([]UptimeGroupResult, len(groups))
+
+ g, gCtx := errgroup.WithContext(ctx)
+ for i, group := range groups {
+ i, group := i, group
+ g.Go(func() error {
+ results[i] = fetchGroupData(gCtx, client, group)
+ return nil
+ })
}
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": monitors,
- })
+
+ g.Wait()
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
}
\ No newline at end of file
diff --git a/controller/user.go b/controller/user.go
index fd53e743..e8ce3c3d 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
+ if setting.DefaultUseAutoGroup {
+ token.Group = "auto"
+ }
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -459,6 +462,9 @@ func GetSelf(c *gin.Context) {
})
return
}
+ // Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
+ user.Remark = ""
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -943,6 +949,7 @@ type UpdateUserSettingRequest struct {
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
+ RecordIpLog bool `json:"record_ip_log"`
}
func UpdateUserSetting(c *gin.Context) {
@@ -1019,6 +1026,7 @@ func UpdateUserSetting(c *gin.Context) {
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
+ constant.UserSettingRecordIpLog: req.RecordIpLog,
}
// 如果是webhook类型,添加webhook相关设置
diff --git a/dto/claude.go b/dto/claude.go
index 4d24bc70..98e09c78 100644
--- a/dto/claude.go
+++ b/dto/claude.go
@@ -178,7 +178,14 @@ type ClaudeRequest struct {
type Thinking struct {
Type string `json:"type"`
- BudgetTokens int `json:"budget_tokens"`
+ BudgetTokens *int `json:"budget_tokens,omitempty"`
+}
+
+func (c *Thinking) GetBudgetTokens() int {
+ if c.BudgetTokens == nil {
+ return 0
+ }
+ return *c.BudgetTokens
}
func (c *ClaudeRequest) IsStringSystem() bool {
diff --git a/dto/openai_request.go b/dto/openai_request.go
index 10e10332..c8355e54 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -58,6 +58,8 @@ type GeneralOpenAIRequest struct {
// OpenRouter Params
Usage json.RawMessage `json:"usage,omitempty"`
Reasoning json.RawMessage `json:"reasoning,omitempty"`
+ // Ali Qwen Params
+ VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
}
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
diff --git a/main.go b/main.go
index c286650f..30ba8092 100644
--- a/main.go
+++ b/main.go
@@ -105,10 +105,12 @@ func main() {
model.InitChannelCache()
}()
- go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
+ // 热更新配置
+ go model.SyncOptions(common.SyncFrequency)
+
// 数据看板
go model.UpdateQuotaData()
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 1bfe1821..5d1c3641 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -49,8 +49,10 @@ func Distribute() func(c *gin.Context) {
}
// check group in common.GroupRatio
if !setting.ContainsGroupRatio(tokenGroup) {
- abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
- return
+ if tokenGroup != "auto" {
+ abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
+ return
+ }
}
userGroup = tokenGroup
}
@@ -95,9 +97,14 @@ func Distribute() func(c *gin.Context) {
}
if shouldSelectChannel {
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
+ var selectGroup string
+ channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+ showGroup := userGroup
+ if userGroup == "auto" {
+ showGroup = fmt.Sprintf("auto(%s)", selectGroup)
+ }
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
diff --git a/model/ability.go b/model/ability.go
index 38b0bd73..96a9ef6a 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -8,6 +8,7 @@ import (
"github.com/samber/lo"
"gorm.io/gorm"
+ "gorm.io/gorm/clause"
)
type Ability struct {
@@ -23,7 +24,7 @@ type Ability struct {
func GetGroupModels(group string) []string {
var models []string
// Find distinct models
- DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
+ DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
return models
}
@@ -41,16 +42,12 @@ func GetAllEnableAbilities() []Ability {
}
func getPriority(group string, model string, retry int) (int, error) {
- trueVal := "1"
- if common.UsingPostgreSQL {
- trueVal = "true"
- }
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
- Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
- Order("priority DESC"). // 按优先级降序排序
+ Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
+ Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
@@ -75,18 +72,14 @@ func getPriority(group string, model string, retry int) (int, error) {
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
- trueVal := "1"
- if common.UsingPostgreSQL {
- trueVal = "true"
- }
- maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
- channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+ maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
+ channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
- channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
+ channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
}
}
@@ -133,9 +126,15 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
+ abilitySet := make(map[string]struct{})
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
+ key := group + "|" + model
+ if _, exists := abilitySet[key]; exists {
+ continue
+ }
+ abilitySet[key] = struct{}{}
ability := Ability{
Group: group,
Model: model,
@@ -152,7 +151,7 @@ func (channel *Channel) AddAbilities() error {
return nil
}
for _, chunk := range lo.Chunk(abilities, 50) {
- err := DB.Create(&chunk).Error
+ err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
if err != nil {
return err
}
@@ -194,9 +193,15 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
// Then add new abilities
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
+ abilitySet := make(map[string]struct{})
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
+ key := group + "|" + model
+ if _, exists := abilitySet[key]; exists {
+ continue
+ }
+ abilitySet[key] = struct{}{}
ability := Ability{
Group: group,
Model: model,
@@ -212,7 +217,7 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
if len(abilities) > 0 {
for _, chunk := range lo.Chunk(abilities, 50) {
- err = tx.Create(&chunk).Error
+ err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
if err != nil {
if isNewTx {
tx.Rollback()
diff --git a/model/cache.go b/model/cache.go
index e2f83e22..3e5eb4c4 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -5,10 +5,13 @@ import (
"fmt"
"math/rand"
"one-api/common"
+ "one-api/setting"
"sort"
"strings"
"sync"
"time"
+
+ "github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]*Channel
@@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) {
}
}
-func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
+ var channel *Channel
+ var err error
+ selectGroup := group
+ if group == "auto" {
+ if len(setting.AutoGroups) == 0 {
+ return nil, selectGroup, errors.New("auto groups is not enabled")
+ }
+ for _, autoGroup := range setting.AutoGroups {
+ if common.DebugEnabled {
+ println("autoGroup:", autoGroup)
+ }
+ channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
+ if channel == nil {
+ continue
+ } else {
+ c.Set("auto_group", autoGroup)
+ selectGroup = autoGroup
+ if common.DebugEnabled {
+ println("selectGroup:", selectGroup)
+ }
+ break
+ }
+ }
+ } else {
+ channel, err = getRandomSatisfiedChannel(group, model, retry)
+ if err != nil {
+ return nil, group, err
+ }
+ }
+ if channel == nil {
+ return nil, group, errors.New("channel not found")
+ }
+ return channel, selectGroup, nil
+}
+
+func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}
diff --git a/model/channel.go b/model/channel.go
index ed7a0a7e..b5503eee 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -145,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
}
// 构造基础查询
- baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+ baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
var whereClause string
@@ -153,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
- groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+ groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
- groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+ groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
@@ -478,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
}
// 构造基础查询
- baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+ baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
var whereClause string
@@ -486,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
- groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+ groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
- groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+ groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
@@ -583,3 +583,17 @@ func BatchSetChannelTag(ids []int, tag *string) error {
// 提交事务
return tx.Commit().Error
}
+
+// CountAllChannels returns total channels in DB
+func CountAllChannels() (int64, error) {
+ var total int64
+ err := DB.Model(&Channel{}).Count(&total).Error
+ return total, err
+}
+
+// CountAllTags returns number of non-empty distinct tags
+func CountAllTags() (int64, error) {
+ var total int64
+ err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
+ return total, err
+}
diff --git a/model/log.go b/model/log.go
index 0a891fcd..b3fd1ad2 100644
--- a/model/log.go
+++ b/model/log.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"one-api/common"
+ "one-api/constant"
"os"
"strings"
"time"
@@ -32,6 +33,7 @@ type Log struct {
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
+ Ip string `json:"ip" gorm:"index;default:''"`
Other string `json:"other"`
}
@@ -61,7 +63,7 @@ func formatUserLogs(logs []*Log) {
func GetLogByKey(key string) (logs []*Log, err error) {
if os.Getenv("LOG_SQL_DSN") != "" {
var tk Token
- if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
+ if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err
}
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
@@ -95,6 +97,15 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
+ // 判断是否需要记录 IP
+ needRecordIp := false
+ if settingMap, err := GetUserSetting(userId, false); err == nil {
+ if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
+ if vb, ok := v.(bool); ok && vb {
+ needRecordIp = true
+ }
+ }
+ }
log := &Log{
UserId: userId,
Username: username,
@@ -111,7 +122,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
- Other: otherStr,
+ Ip: func() string {
+ if needRecordIp {
+ return c.ClientIP()
+ }
+ return ""
+ }(),
+ Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -128,6 +145,15 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
}
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
+ // 判断是否需要记录 IP
+ needRecordIp := false
+ if settingMap, err := GetUserSetting(userId, false); err == nil {
+ if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
+ if vb, ok := v.(bool); ok && vb {
+ needRecordIp = true
+ }
+ }
+ }
log := &Log{
UserId: userId,
Username: username,
@@ -144,7 +170,13 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
- Other: otherStr,
+ Ip: func() string {
+ if needRecordIp {
+ return c.ClientIP()
+ }
+ return ""
+ }(),
+ Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -184,7 +216,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = tx.Where("logs.channel_id = ?", channel)
}
if group != "" {
- tx = tx.Where("logs."+groupCol+" = ?", group)
+ tx = tx.Where("logs."+logGroupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
@@ -195,13 +227,18 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
return nil, 0, err
}
- channelIds := make([]int, 0)
+ channelIdsMap := make(map[int]struct{})
channelMap := make(map[int]string)
for _, log := range logs {
if log.ChannelId != 0 {
- channelIds = append(channelIds, log.ChannelId)
+ channelIdsMap[log.ChannelId] = struct{}{}
}
}
+
+ channelIds := make([]int, 0, len(channelIdsMap))
+ for channelId := range channelIdsMap {
+ channelIds = append(channelIds, channelId)
+ }
if len(channelIds) > 0 {
var channels []struct {
Id int `gorm:"column:id"`
@@ -242,7 +279,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = tx.Where("logs.created_at <= ?", endTimestamp)
}
if group != "" {
- tx = tx.Where("logs."+groupCol+" = ?", group)
+ tx = tx.Where("logs."+logGroupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
@@ -303,8 +340,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
}
if group != "" {
- tx = tx.Where(groupCol+" = ?", group)
- rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
+ tx = tx.Where(logGroupCol+" = ?", group)
+ rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
}
tx = tx.Where("type = ?", LogTypeConsume)
diff --git a/model/main.go b/model/main.go
index 61d6bb10..965bba93 100644
--- a/model/main.go
+++ b/model/main.go
@@ -1,6 +1,7 @@
package model
import (
+ "fmt"
"log"
"one-api/common"
"one-api/constant"
@@ -15,18 +16,39 @@ import (
"gorm.io/gorm"
)
-var groupCol string
-var keyCol string
+var commonGroupCol string
+var commonKeyCol string
+var commonTrueVal string
+var commonFalseVal string
+
+var logKeyCol string
+var logGroupCol string
func initCol() {
+ // init common column names
if common.UsingPostgreSQL {
- groupCol = `"group"`
- keyCol = `"key"`
-
+ commonGroupCol = `"group"`
+ commonKeyCol = `"key"`
+ commonTrueVal = "true"
+ commonFalseVal = "false"
} else {
- groupCol = "`group`"
- keyCol = "`key`"
+ commonGroupCol = "`group`"
+ commonKeyCol = "`key`"
+ commonTrueVal = "1"
+ commonFalseVal = "0"
}
+ if os.Getenv("LOG_SQL_DSN") != "" {
+ switch common.LogSqlType {
+ case common.DatabaseTypePostgreSQL:
+ logGroupCol = `"group"`
+ logKeyCol = `"key"`
+ default:
+ logGroupCol = commonGroupCol
+ logKeyCol = commonKeyCol
+ }
+ }
+ // log sql type and database type
+ common.SysLog("Using Log SQL Type: " + common.LogSqlType)
}
var DB *gorm.DB
@@ -83,7 +105,7 @@ func CheckSetup() {
}
}
-func chooseDB(envName string) (*gorm.DB, error) {
+func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
defer func() {
initCol()
}()
@@ -92,7 +114,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
- common.UsingPostgreSQL = true
+ if !isLog {
+ common.UsingPostgreSQL = true
+ } else {
+ common.LogSqlType = common.DatabaseTypePostgreSQL
+ }
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
@@ -102,7 +128,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
}
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
- common.UsingSQLite = true
+ if !isLog {
+ common.UsingSQLite = true
+ } else {
+ common.LogSqlType = common.DatabaseTypeSQLite
+ }
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -117,7 +147,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
dsn += "?parseTime=true"
}
}
- common.UsingMySQL = true
+ if !isLog {
+ common.UsingMySQL = true
+ } else {
+ common.LogSqlType = common.DatabaseTypeMySQL
+ }
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -131,7 +165,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
}
func InitDB() (err error) {
- db, err := chooseDB("SQL_DSN")
+ db, err := chooseDB("SQL_DSN", false)
if err == nil {
if common.DebugEnabled {
db = db.Debug()
@@ -149,7 +183,7 @@ func InitDB() (err error) {
return nil
}
if common.UsingMySQL {
- _, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
+ //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
}
common.SysLog("database migration started")
err = migrateDB()
@@ -165,7 +199,7 @@ func InitLogDB() (err error) {
LOG_DB = DB
return
}
- db, err := chooseDB("LOG_SQL_DSN")
+ db, err := chooseDB("LOG_SQL_DSN", true)
if err == nil {
if common.DebugEnabled {
db = db.Debug()
@@ -198,54 +232,73 @@ func InitLogDB() (err error) {
}
func migrateDB() error {
- err := DB.AutoMigrate(&Channel{})
+ if !common.UsingPostgreSQL {
+ return migrateDBFast()
+ }
+ err := DB.AutoMigrate(
+ &Channel{},
+ &Token{},
+ &User{},
+ &Option{},
+ &Redemption{},
+ &Ability{},
+ &Log{},
+ &Midjourney{},
+ &TopUp{},
+ &QuotaData{},
+ &Task{},
+ &Setup{},
+ )
if err != nil {
return err
}
- err = DB.AutoMigrate(&Token{})
- if err != nil {
- return err
+ return nil
+}
+
+func migrateDBFast() error {
+ var wg sync.WaitGroup
+ errChan := make(chan error, 12) // Buffer size matches number of migrations
+
+ migrations := []struct {
+ model interface{}
+ name string
+ }{
+ {&Channel{}, "Channel"},
+ {&Token{}, "Token"},
+ {&User{}, "User"},
+ {&Option{}, "Option"},
+ {&Redemption{}, "Redemption"},
+ {&Ability{}, "Ability"},
+ {&Log{}, "Log"},
+ {&Midjourney{}, "Midjourney"},
+ {&TopUp{}, "TopUp"},
+ {&QuotaData{}, "QuotaData"},
+ {&Task{}, "Task"},
+ {&Setup{}, "Setup"},
}
- err = DB.AutoMigrate(&User{})
- if err != nil {
- return err
+
+ for _, m := range migrations {
+ wg.Add(1)
+ go func(model interface{}, name string) {
+ defer wg.Done()
+ if err := DB.AutoMigrate(model); err != nil {
+ errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
+ }
+ }(m.model, m.name)
}
- err = DB.AutoMigrate(&Option{})
- if err != nil {
- return err
+
+ // Wait for all migrations to complete
+ wg.Wait()
+ close(errChan)
+
+ // Check for any errors
+ for err := range errChan {
+ if err != nil {
+ return err
+ }
}
- err = DB.AutoMigrate(&Redemption{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Ability{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Log{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Midjourney{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&TopUp{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&QuotaData{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Task{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated")
- //err = createRootAccountIfNeed()
- return err
+ return nil
}
func migrateLOGDB() error {
diff --git a/model/midjourney.go b/model/midjourney.go
index 5f85abfd..e8140447 100644
--- a/model/midjourney.go
+++ b/model/midjourney.go
@@ -166,3 +166,40 @@ func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
Where("id in (?)", taskIDs).
Updates(params).Error
}
+
+// CountAllTasks returns total midjourney tasks for admin query
+func CountAllTasks(queryParams TaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Midjourney{})
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.MjID != "" {
+ query = query.Where("mj_id = ?", queryParams.MjID)
+ }
+ if queryParams.StartTimestamp != "" {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != "" {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
+
+// CountAllUserTask returns total midjourney tasks for user
+func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Midjourney{}).Where("user_id = ?", userId)
+ if queryParams.MjID != "" {
+ query = query.Where("mj_id = ?", queryParams.MjID)
+ }
+ if queryParams.StartTimestamp != "" {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != "" {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
diff --git a/model/option.go b/model/option.go
index 42949e8b..1391b203 100644
--- a/model/option.go
+++ b/model/option.go
@@ -76,6 +76,8 @@ func InitOptionMap() {
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
+ common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
+ common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -98,6 +100,7 @@ func InitOptionMap() {
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
+ common.OptionMap["GroupGroupRatio"] = setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
@@ -122,9 +125,6 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
- common.OptionMap["ApiInfo"] = ""
- common.OptionMap["UptimeKumaUrl"] = ""
- common.OptionMap["UptimeKumaSlug"] = ""
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -194,7 +194,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
- if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
+ if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@@ -263,6 +263,8 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
+ case "DefaultUseAutoGroup":
+ setting.DefaultUseAutoGroup = boolValue
}
}
switch key {
@@ -289,6 +291,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
+ case "AutoGroups":
+ err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
@@ -357,6 +361,8 @@ func updateOptionMap(key string, value string) (err error) {
err = operation_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = setting.UpdateGroupRatioByJSONString(value)
+ case "GroupGroupRatio":
+ err = setting.UpdateGroupGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
diff --git a/model/redemption.go b/model/redemption.go
index 89c4ac8c..bf237668 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -21,6 +21,7 @@ type Redemption struct {
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
DeletedAt gorm.DeletedAt `gorm:"index"`
+ ExpiredTime int64 `json:"expired_time" gorm:"bigint"` // 过期时间,0 表示不过期
}
func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
@@ -131,6 +132,9 @@ func Redeem(key string, userId int) (quota int, err error) {
if redemption.Status != common.RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
+ if redemption.ExpiredTime != 0 && redemption.ExpiredTime < common.GetTimestamp() {
+ return errors.New("该兑换码已过期")
+ }
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
if err != nil {
return err
@@ -162,7 +166,7 @@ func (redemption *Redemption) SelectUpdate() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (redemption *Redemption) Update() error {
var err error
- err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
+ err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time", "expired_time").Updates(redemption).Error
return err
}
@@ -183,3 +187,9 @@ func DeleteRedemptionById(id int) (err error) {
}
return redemption.Delete()
}
+
+func DeleteInvalidRedemptions() (int64, error) {
+ now := common.GetTimestamp()
+ result := DB.Where("status IN ? OR (status = ? AND expired_time != 0 AND expired_time < ?)", []int{common.RedemptionCodeStatusUsed, common.RedemptionCodeStatusDisabled}, common.RedemptionCodeStatusEnabled, now).Delete(&Redemption{})
+ return result.RowsAffected, result.Error
+}
diff --git a/model/task.go b/model/task.go
index df221edf..9e4177ba 100644
--- a/model/task.go
+++ b/model/task.go
@@ -302,3 +302,64 @@ func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, e
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
return stat, err
}
+
+// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
+func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Task{})
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.UserID != "" {
+ query = query.Where("user_id = ?", queryParams.UserID)
+ }
+ if len(queryParams.UserIDs) != 0 {
+ query = query.Where("user_id in (?)", queryParams.UserIDs)
+ }
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
+
+// TaskCountAllUserTask returns total tasks for given user
+func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Task{}).Where("user_id = ?", userId)
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
diff --git a/model/token.go b/model/token.go
index 8587ea62..2ed2c09a 100644
--- a/model/token.go
+++ b/model/token.go
@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
if token != "" {
token = strings.Trim(token, "sk-")
}
- err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
+ err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
}
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
// Don't return error - fall through to DB
}
fromDB = true
- err = DB.Where(keyCol+" = ?", key).First(&token).Error
+ err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
return token, err
}
@@ -320,3 +320,10 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
+
+// CountUserTokens returns total number of tokens for the given user, used for pagination
+func CountUserTokens(userId int) (int64, error) {
+ var total int64
+ err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
+ return total, err
+}
diff --git a/model/token_cache.go b/model/token_cache.go
index b2e0c951..a4b0beae 100644
--- a/model/token_cache.go
+++ b/model/token_cache.go
@@ -10,7 +10,7 @@ import (
func cacheSetToken(token Token) error {
key := common.GenerateHMAC(token.Key)
token.Clean()
- err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
+ err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
if err != nil {
return err
}
diff --git a/model/user.go b/model/user.go
index 1a3372aa..6a695457 100644
--- a/model/user.go
+++ b/model/user.go
@@ -41,6 +41,7 @@ type User struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
+ Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
}
func (user *User) ToBaseUser() *UserBase {
@@ -175,7 +176,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
// 如果是数字,同时搜索ID和其他字段
likeCondition = "id = ? OR " + likeCondition
if group != "" {
- query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+ query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
@@ -184,7 +185,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
} else {
// 非数字关键字,只搜索字符串字段
if group != "" {
- query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+ query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
@@ -366,6 +367,7 @@ func (user *User) Edit(updatePassword bool) error {
"display_name": newUser.DisplayName,
"group": newUser.Group,
"quota": newUser.Quota,
+ "remark": newUser.Remark,
}
if updatePassword {
updates["password"] = newUser.Password
@@ -615,7 +617,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
// Don't return error - fall through to DB
}
fromDB = true
- err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
+ err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
if err != nil {
return "", err
}
diff --git a/model/user_cache.go b/model/user_cache.go
index d74877bd..e673defc 100644
--- a/model/user_cache.go
+++ b/model/user_cache.go
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
+ time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
)
}
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index cb2c75b1..ba20adea 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -113,7 +113,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
- BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
+ BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
@@ -454,6 +454,7 @@ type ClaudeResponseInfo struct {
Model string
ResponseText strings.Builder
Usage *dto.Usage
+ Done bool
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
@@ -461,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
- // message_start, 获取usage
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
+
+ // message_start, 获取usage
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
+ claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
+ if claudeResponse.Delta.Thinking != "" {
+ claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
+ }
} else if claudeResponse.Type == "message_delta" {
- claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+ // 最终的usage获取
if claudeResponse.Usage.InputTokens > 0 {
+ // 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
- claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
+ claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+ claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
+
+ // 判断是否完整
+ claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
@@ -506,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
+ FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
+
if requestMode == RequestModeCompletion {
- claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
info.UpstreamModelName = claudeResponse.Message.Model
- claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
- claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
- claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
} else if claudeResponse.Type == "message_delta" {
- if claudeResponse.Usage.InputTokens > 0 {
- // 不叠加,只取最新的
- claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
- }
- claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
- claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
@@ -544,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
+
+ if requestMode == RequestModeCompletion {
+ claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+ } else {
+ if claudeInfo.Usage.PromptTokens == 0 {
+ //上游出错
+ }
+ if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
+ if common.DebugEnabled {
+ common.SysError("claude response usage is not complete, maybe upstream error")
+ }
+ claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+ }
+ }
+
if info.RelayFormat == relaycommon.RelayFormatClaude {
- if requestMode == RequestModeCompletion {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
- } else {
- // 说明流模式建立失败,可能为官方出错
- if claudeInfo.Usage.PromptTokens == 0 {
- //usage.PromptTokens = info.PromptTokens
- }
- if claudeInfo.Usage.CompletionTokens == 0 {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
- }
- }
+ //
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
- if requestMode == RequestModeCompletion {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
- } else {
- if claudeInfo.Usage.PromptTokens == 0 {
- //上游出错
- }
- if claudeInfo.Usage.CompletionTokens == 0 {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
- }
- }
+
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go
index 10c4328b..8a044bf2 100644
--- a/relay/channel/cohere/relay-cohere.go
+++ b/relay/channel/cohere/relay-cohere.go
@@ -3,7 +3,6 @@ package cohere
import (
"bufio"
"encoding/json"
- "fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -78,7 +77,7 @@ func stopReasonCohere2OpenAI(reason string) string {
}
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+ responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
responseText := ""
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index e6f66d5f..a81eb3a9 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -72,8 +72,11 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- // suffix -thinking and -nothinking
- if strings.HasSuffix(info.OriginModelName, "-thinking") {
+ // 新增逻辑:处理 -thinking- 格式
+ if strings.Contains(info.OriginModelName, "-thinking-") {
+ parts := strings.Split(info.UpstreamModelName, "-thinking-")
+ info.UpstreamModelName = parts[0]
+ } else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index e2288faf..635041d7 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -12,6 +12,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
+ "strconv"
"strings"
"unicode/utf8"
@@ -36,6 +37,47 @@ var geminiSupportedMimeTypes = map[string]bool{
"video/flv": true,
}
+// Gemini 允许的思考预算范围
+const (
+ pro25MinBudget = 128
+ pro25MaxBudget = 32768
+ flash25MaxBudget = 24576
+ flash25LiteMinBudget = 512
+ flash25LiteMaxBudget = 24576
+)
+
+// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
+func clampThinkingBudget(modelName string, budget int) int {
+ isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
+ is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
+
+ if is25FlashLite {
+ if budget < flash25LiteMinBudget {
+ return flash25LiteMinBudget
+ }
+ if budget > flash25LiteMaxBudget {
+ return flash25LiteMaxBudget
+ }
+ } else if isNew25Pro {
+ if budget < pro25MinBudget {
+ return pro25MinBudget
+ }
+ if budget > pro25MaxBudget {
+ return pro25MaxBudget
+ }
+ } else { // 其他模型
+ if budget < 0 {
+ return 0
+ }
+ if budget > flash25MaxBudget {
+ return flash25MaxBudget
+ }
+ }
+ return budget
+}
+
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
@@ -57,16 +99,31 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- if strings.HasSuffix(info.OriginModelName, "-thinking") {
- // 硬编码不支持 ThinkingBudget 的旧模型
+ modelName := info.OriginModelName
+ isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
+ is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
+
+ if strings.Contains(modelName, "-thinking-") {
+ parts := strings.SplitN(modelName, "-thinking-", 2)
+ if len(parts) == 2 && parts[1] != "" {
+ if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
+ clampedBudget := clampThinkingBudget(modelName, budgetTokens)
+ geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+ ThinkingBudget: common.GetPointer(clampedBudget),
+ IncludeThoughts: true,
+ }
+ }
+ }
+ } else if strings.HasSuffix(modelName, "-thinking") {
unsupportedModels := []string{
"gemini-2.5-pro-preview-05-06",
"gemini-2.5-pro-preview-03-25",
}
-
isUnsupported := false
for _, unsupportedModel := range unsupportedModels {
- if strings.HasPrefix(info.OriginModelName, unsupportedModel) {
+ if strings.HasPrefix(modelName, unsupportedModel) {
isUnsupported = true
break
}
@@ -78,39 +135,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
} else {
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
-
- // 检查是否为新的2.5pro模型(支持ThinkingBudget但有特殊范围)
- isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
- !strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
- !strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
-
- if isNew25Pro {
- // 新的2.5pro模型:ThinkingBudget范围为128-32768
- if budgetTokens == 0 || budgetTokens < 128 {
- budgetTokens = 128
- } else if budgetTokens > 32768 {
- budgetTokens = 32768
- }
- } else {
- // 其他模型:ThinkingBudget范围为0-24576
- if budgetTokens == 0 || budgetTokens > 24576 {
- budgetTokens = 24576
- }
- }
-
+ clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
- ThinkingBudget: common.GetPointer(int(budgetTokens)),
+ ThinkingBudget: common.GetPointer(clampedBudget),
IncludeThoughts: true,
}
}
- } else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
- // 检查是否为新的2.5pro模型(不支持-nothinking,因为最低值只能为128)
- isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
- !strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
- !strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
-
- if !isNew25Pro {
- // 只有非新2.5pro模型才支持-nothinking
+ } else if strings.HasSuffix(modelName, "-nothinking") {
+ if !isNew25Pro && !is25FlashLite {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(0),
}
@@ -283,7 +315,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
// 校验 MimeType 是否在 Gemini 支持的白名单中
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
- return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
+ url := part.GetImageMedia().Url
+ return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
}
parts = append(parts, GeminiPart{
@@ -611,9 +644,9 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
}
}
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
+func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ Id: helper.GetResponseID(c),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
@@ -754,7 +787,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// responseText := ""
- id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+ id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
var imageCount int
@@ -849,7 +882,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
StatusCode: resp.StatusCode,
}, nil
}
- fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
+ fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index cef958b2..ea24d811 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -88,6 +88,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
+
+ // 特殊处理 responses API
+ if info.RelayMode == constant.RelayModeResponses {
+ requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
+ return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
+ }
+
model_ := info.UpstreamModelName
// 2025年5月10日后创建的渠道不移除.
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 2e3d8df1..4dc0fc60 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -8,6 +8,7 @@ import (
"math"
"mime/multipart"
"net/http"
+ "path/filepath"
"one-api/common"
"one-api/constant"
"one-api/dto"
@@ -345,13 +346,14 @@ func countAudioTokens(c *gin.Context) (int, error) {
if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err)
}
-
+ ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
reqFp, err := reqBody.File.Open()
if err != nil {
return 0, errors.WithStack(err)
}
+ defer reqFp.Close()
- tmpFp, err := os.CreateTemp("", "audio-*")
+ tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.WithStack(err)
}
@@ -365,7 +367,7 @@ func countAudioTokens(c *gin.Context) (int, error) {
return 0, errors.WithStack(err)
}
- duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
+ duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
if err != nil {
return 0, errors.WithStack(err)
}
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index 5c398b5e..0c6f8641 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -2,7 +2,6 @@ package palm
import (
"encoding/json"
- "fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -73,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+ responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
index fb68a88a..e8805255 100644
--- a/relay/claude_handler.go
+++ b/relay/claude_handler.go
@@ -98,7 +98,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
// BudgetTokens 为 max_tokens 的 80%
textRequest.Thinking = &dto.Thinking{
Type: "enabled",
- BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
+ BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index f4fc3c1e..a842a58d 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -61,6 +61,7 @@ type RelayInfo struct {
TokenKey string
UserId int
Group string
+ UserGroup string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
@@ -204,6 +205,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
TokenKey: tokenKey,
UserId: userId,
Group: group,
+ UserGroup: c.GetString(constant.ContextKeyUserGroup),
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
diff --git a/relay/helper/price.go b/relay/helper/price.go
index 89efa1da..326790b4 100644
--- a/relay/helper/price.go
+++ b/relay/helper/price.go
@@ -2,14 +2,20 @@ package helper
import (
"fmt"
- "github.com/gin-gonic/gin"
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting"
"one-api/setting/operation_setting"
+
+ "github.com/gin-gonic/gin"
)
+type GroupRatioInfo struct {
+ GroupRatio float64
+ GroupSpecialRatio float64
+}
+
type PriceData struct {
ModelPrice float64
ModelRatio float64
@@ -17,18 +23,50 @@ type PriceData struct {
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
- GroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
+ GroupRatioInfo GroupRatioInfo
}
func (p PriceData) ToSetting() string {
- return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+ return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+}
+
+// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
+func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
+ groupRatioInfo := GroupRatioInfo{
+ GroupRatio: 1.0, // default ratio
+ GroupSpecialRatio: 1.0, // default user group ratio
+ }
+
+ // check auto group
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ if common.DebugEnabled {
+ println(fmt.Sprintf("final group: %s", autoGroup))
+ }
+ relayInfo.Group = autoGroup.(string)
+ }
+
+ // check user group special ratio
+ userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
+ if ok {
+ // user group special ratio
+ groupRatioInfo.GroupSpecialRatio = userGroupRatio
+ groupRatioInfo.GroupRatio = userGroupRatio
+ } else {
+ // normal group ratio
+ groupRatioInfo.GroupRatio = setting.GetGroupRatio(relayInfo.Group)
+ }
+
+ return groupRatioInfo
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
- groupRatio := setting.GetGroupRatio(info.Group)
+
+ groupRatioInfo := HandleGroupRatio(c, info)
+
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
@@ -58,17 +96,17 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
- ratio := modelRatio * groupRatio
+ ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+ preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
priceData := PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
- GroupRatio: groupRatio,
+ GroupRatioInfo: groupRatioInfo,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
diff --git a/relay/relay-gemini.go b/relay/relay-gemini.go
index 93a2b7aa..21cf5e12 100644
--- a/relay/relay-gemini.go
+++ b/relay/relay-gemini.go
@@ -136,6 +136,20 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
adaptor.Init(relayInfo)
+ // Clean up empty system instruction
+ if req.SystemInstructions != nil {
+ hasContent := false
+ for _, part := range req.SystemInstructions.Parts {
+ if part.Text != "" {
+ hasContent = true
+ break
+ }
+ }
+ if !hasContent {
+ req.SystemInstructions = nil
+ }
+ }
+
requestBody, err := json.Marshal(req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
diff --git a/relay/relay-image.go b/relay/relay-image.go
index dc63cce8..197a8af6 100644
--- a/relay/relay-image.go
+++ b/relay/relay-image.go
@@ -162,7 +162,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// reset model price
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
- quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
+ quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
diff --git a/relay/relay-text.go b/relay/relay-text.go
index a48a664a..24fb8155 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -90,15 +90,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateTextRequest(c, relayInfo)
- if textRequest.WebSearchOptions != nil {
- c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
- }
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
+ if textRequest.WebSearchOptions != nil {
+ c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
+ }
+
if setting.ShouldCheckPromptSensitive() {
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
@@ -361,7 +362,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
cacheRatio := priceData.CacheRatio
imageRatio := priceData.ImageRatio
modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
// Convert values to decimal for precise calculation
@@ -510,7 +511,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
- other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
+ other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio
diff --git a/relay/websocket.go b/relay/websocket.go
index c815eb71..571f3a82 100644
--- a/relay/websocket.go
+++ b/relay/websocket.go
@@ -6,12 +6,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
- "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
"one-api/service"
- "one-api/setting"
- "one-api/setting/operation_setting"
)
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
//isModelMapped = true
}
}
- //relayInfo.UpstreamModelName = textRequest.Model
- modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
- //err := service.SensitiveWordsCheck(textRequest)
-
- //if constant.ShouldCheckPromptSensitive() {
- // err = checkRequestSensitive(textRequest, relayInfo)
- // if err != nil {
- // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
- // }
- //}
-
- //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
- //// count messages token error 计算promptTokens错误
- //if err != nil {
- // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
- //}
- //
- if !getModelPriceSuccess {
- preConsumedTokens := common.PreConsumedQuota
- //if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
- // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
- //}
- modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- relayInfo.UsePrice = true
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
return openaiErr
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
- userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+ userQuota, priceData, "")
return nil
}
diff --git a/router/api-router.go b/router/api-router.go
index 0ab8be7f..45930246 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -81,6 +81,7 @@ func SetApiRouter(router *gin.Engine) {
optionRoute.GET("/", controller.GetOptions)
optionRoute.PUT("/", controller.UpdateOption)
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
+ optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
}
channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth())
@@ -126,6 +127,7 @@ func SetApiRouter(router *gin.Engine) {
redemptionRoute.GET("/:id", controller.GetRedemption)
redemptionRoute.POST("/", controller.AddRedemption)
redemptionRoute.PUT("/", controller.UpdateRedemption)
+ redemptionRoute.DELETE("/invalid", controller.DeleteInvalidRedemption)
redemptionRoute.DELETE("/:id", controller.DeleteRedemption)
}
logRoute := apiRouter.Group("/log")
diff --git a/service/channel.go b/service/channel.go
index e3a76af4..746e9a34 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -59,6 +59,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
return true
case "billing_not_active":
return true
+ case "pre_consume_token_quota_failed":
+ return true
}
switch err.Error.Type {
case "insufficient_quota":
diff --git a/service/convert.go b/service/convert.go
index 53aefb62..df7acf0d 100644
--- a/service/convert.go
+++ b/service/convert.go
@@ -21,10 +21,10 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
- if claudeRequest.Thinking != nil {
+ if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
if isOpenRouter {
reasoning := openrouter.RequestReasoning{
- MaxTokens: claudeRequest.Thinking.BudgetTokens,
+ MaxTokens: claudeRequest.Thinking.GetBudgetTokens(),
}
reasoningJSON, err := json.Marshal(reasoning)
if err != nil {
diff --git a/service/error.go b/service/error.go
index 1bf5992b..f3d8a17d 100644
--- a/service/error.go
+++ b/service/error.go
@@ -29,9 +29,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
- if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
- text = "请求上游地址失败"
+ if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
+ if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+ common.SysLog(fmt.Sprintf("error: %s", text))
+ text = "请求上游地址失败"
+ }
}
openAIError := dto.OpenAIError{
Message: text,
@@ -53,9 +55,11 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
- if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
- text = "请求上游地址失败"
+ if !strings.HasPrefix(lowerText, "get file base64 from url") {
+ if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+ common.SysLog(fmt.Sprintf("error: %s", text))
+ text = "请求上游地址失败"
+ }
}
claudeError := dto.ClaudeError{
Message: text,
diff --git a/service/file_decoder.go b/service/file_decoder.go
index bbb188f8..c1d4fb0c 100644
--- a/service/file_decoder.go
+++ b/service/file_decoder.go
@@ -4,8 +4,10 @@ import (
"encoding/base64"
"fmt"
"io"
+ "one-api/common"
"one-api/constant"
"one-api/dto"
+ "strings"
)
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
@@ -30,9 +32,104 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
// Convert to base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
+ mimeType := resp.Header.Get("Content-Type")
+ if len(strings.Split(mimeType, ";")) > 1 {
+ // If Content-Type has parameters, take the first part
+ mimeType = strings.Split(mimeType, ";")[0]
+ }
+ if mimeType == "application/octet-stream" {
+ if common.DebugEnabled {
+ println("MIME type is application/octet-stream, trying to guess from URL or filename")
+ }
+ // try to guess the MIME type from the url last segment
+ urlParts := strings.Split(url, "/")
+ if len(urlParts) > 0 {
+ lastSegment := urlParts[len(urlParts)-1]
+ if strings.Contains(lastSegment, ".") {
+ // Extract the file extension
+ filename := strings.Split(lastSegment, ".")
+ if len(filename) > 1 {
+ ext := strings.ToLower(filename[len(filename)-1])
+ // Guess MIME type based on file extension
+ mimeType = GetMimeTypeByExtension(ext)
+ }
+ }
+ } else {
+ // try to guess the MIME type from the file extension
+ fileName := resp.Header.Get("Content-Disposition")
+ if fileName != "" {
+ // Extract the filename from the Content-Disposition header
+ parts := strings.Split(fileName, ";")
+ for _, part := range parts {
+ if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
+ fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
+ // Remove quotes if present
+ if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
+ fileName = fileName[1 : len(fileName)-1]
+ }
+ // Guess MIME type based on file extension
+ if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
+ mimeType = GetMimeTypeByExtension(ext)
+ }
+ break
+ }
+ }
+ }
+ }
+ }
+
return &dto.LocalFileData{
Base64Data: base64Data,
- MimeType: resp.Header.Get("Content-Type"),
+ MimeType: mimeType,
Size: int64(len(fileBytes)),
}, nil
}
+
+func GetMimeTypeByExtension(ext string) string {
+ // Convert to lowercase for case-insensitive comparison
+ ext = strings.ToLower(ext)
+ switch ext {
+ // Text files
+ case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
+ return "text/plain"
+
+ // Image files
+ case "jpg", "jpeg":
+ return "image/jpeg"
+ case "png":
+ return "image/png"
+ case "gif":
+ return "image/gif"
+
+ // Audio files
+ case "mp3":
+ return "audio/mp3"
+ case "wav":
+ return "audio/wav"
+ case "mpeg":
+ return "audio/mpeg"
+
+ // Video files
+ case "mp4":
+ return "video/mp4"
+ case "wmv":
+ return "video/wmv"
+ case "flv":
+ return "video/flv"
+ case "mov":
+ return "video/mov"
+ case "mpg":
+ return "video/mpg"
+ case "avi":
+ return "video/avi"
+ case "mpegps":
+ return "video/mpegps"
+
+ // Document files
+ case "pdf":
+ return "application/pdf"
+
+ default:
+ return "application/octet-stream" // Default for unknown types
+ }
+}
diff --git a/service/log_info_generate.go b/service/log_info_generate.go
index 75457b97..1edc9073 100644
--- a/service/log_info_generate.go
+++ b/service/log_info_generate.go
@@ -8,7 +8,7 @@ import (
)
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
- cacheTokens int, cacheRatio float64, modelPrice float64) map[string]interface{} {
+ cacheTokens int, cacheRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
@@ -16,6 +16,7 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
other["cache_tokens"] = cacheTokens
other["cache_ratio"] = cacheRatio
other["model_price"] = modelPrice
+ other["user_group_ratio"] = userGroupRatio
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
if relayInfo.ReasoningEffort != "" {
other["reasoning_effort"] = relayInfo.ReasoningEffort
@@ -30,8 +31,8 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
return other
}
-func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
+func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
info["ws"] = true
info["audio_input"] = usage.InputTokenDetails.AudioTokens
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
@@ -42,8 +43,8 @@ func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us
return info
}
-func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
+func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
info["audio"] = true
info["audio_input"] = usage.PromptTokensDetails.AudioTokens
info["audio_output"] = usage.CompletionTokenDetails.AudioTokens
@@ -55,8 +56,8 @@ func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
- cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
+ cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
info["claude"] = true
info["cache_creation_tokens"] = cacheCreationTokens
info["cache_creation_ratio"] = cacheCreationRatio
diff --git a/service/quota.go b/service/quota.go
index 43297b4a..8c7ed07e 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -3,6 +3,7 @@ package service
import (
"errors"
"fmt"
+ "log"
"math"
"one-api/common"
constant2 "one-api/constant"
@@ -97,6 +98,19 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
groupRatio := setting.GetGroupRatio(relayInfo.Group)
modelRatio, _ := operation_setting.GetModelRatio(modelName)
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ groupRatio = setting.GetGroupRatio(autoGroup.(string))
+ log.Printf("final group ratio: %f", groupRatio)
+ relayInfo.Group = autoGroup.(string)
+ }
+
+ actualGroupRatio := groupRatio
+ userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
+ if ok {
+ actualGroupRatio = userGroupRatio
+ }
+
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
TextTokens: textInputTokens,
@@ -109,7 +123,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
ModelName: modelName,
UsePrice: relayInfo.UsePrice,
ModelRatio: modelRatio,
- GroupRatio: groupRatio,
+ GroupRatio: actualGroupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -131,8 +145,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
- usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
- modelPrice float64, usePrice bool, extraContent string) {
+ usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
@@ -146,6 +159,11 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
+ modelPrice := priceData.ModelPrice
+ usePrice := priceData.UsePrice
+
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
TextTokens: textInputTokens,
@@ -190,7 +208,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -206,9 +224,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
tokenName := ctx.GetString("token_name")
completionRatio := priceData.CompletionRatio
modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
-
cacheRatio := priceData.CacheRatio
cacheTokens := usage.PromptTokensDetails.CachedTokens
@@ -262,7 +279,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
- cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice)
+ cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -304,7 +321,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
@@ -360,7 +377,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
diff --git a/setting/auto_group.go b/setting/auto_group.go
new file mode 100644
index 00000000..5a87ae56
--- /dev/null
+++ b/setting/auto_group.go
@@ -0,0 +1,31 @@
+package setting
+
+import "encoding/json"
+
+var AutoGroups = []string{
+ "default",
+}
+
+var DefaultUseAutoGroup = false
+
+func ContainsAutoGroup(group string) bool {
+ for _, autoGroup := range AutoGroups {
+ if autoGroup == group {
+ return true
+ }
+ }
+ return false
+}
+
+func UpdateAutoGroupsByJsonString(jsonString string) error {
+ AutoGroups = make([]string, 0)
+ return json.Unmarshal([]byte(jsonString), &AutoGroups)
+}
+
+func AutoGroups2JsonString() string {
+ jsonBytes, err := json.Marshal(AutoGroups)
+ if err != nil {
+ return "[]"
+ }
+ return string(jsonBytes)
+}
diff --git a/setting/console.go b/setting/console.go
deleted file mode 100644
index 94023666..00000000
--- a/setting/console.go
+++ /dev/null
@@ -1,327 +0,0 @@
-package setting
-
-import (
- "encoding/json"
- "fmt"
- "net/url"
- "one-api/common"
- "regexp"
- "sort"
- "strings"
- "time"
-)
-
-// ValidateConsoleSettings 验证控制台设置信息格式
-func ValidateConsoleSettings(settingsStr string, settingType string) error {
- if settingsStr == "" {
- return nil // 空字符串是合法的
- }
-
- switch settingType {
- case "ApiInfo":
- return validateApiInfo(settingsStr)
- case "Announcements":
- return validateAnnouncements(settingsStr)
- case "FAQ":
- return validateFAQ(settingsStr)
- default:
- return fmt.Errorf("未知的设置类型:%s", settingType)
- }
-}
-
-// validateApiInfo 验证API信息格式
-func validateApiInfo(apiInfoStr string) error {
- var apiInfoList []map[string]interface{}
- if err := json.Unmarshal([]byte(apiInfoStr), &apiInfoList); err != nil {
- return fmt.Errorf("API信息格式错误:%s", err.Error())
- }
-
- // 验证数组长度
- if len(apiInfoList) > 50 {
- return fmt.Errorf("API信息数量不能超过50个")
- }
-
- // 允许的颜色值
- validColors := map[string]bool{
- "blue": true, "green": true, "cyan": true, "purple": true, "pink": true,
- "red": true, "orange": true, "amber": true, "yellow": true, "lime": true,
- "light-green": true, "teal": true, "light-blue": true, "indigo": true,
- "violet": true, "grey": true,
- }
-
- // URL正则表达式,支持域名和IP地址格式
- // 域名格式:https://example.com 或 https://sub.example.com:8080
- // IP地址格式:https://192.168.1.1 或 https://192.168.1.1:8080
- urlRegex := regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?::[0-9]{1,5})?(?:/.*)?$`)
-
- for i, apiInfo := range apiInfoList {
- // 检查必填字段
- urlStr, ok := apiInfo["url"].(string)
- if !ok || urlStr == "" {
- return fmt.Errorf("第%d个API信息缺少URL字段", i+1)
- }
-
- route, ok := apiInfo["route"].(string)
- if !ok || route == "" {
- return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1)
- }
-
- description, ok := apiInfo["description"].(string)
- if !ok || description == "" {
- return fmt.Errorf("第%d个API信息缺少说明字段", i+1)
- }
-
- color, ok := apiInfo["color"].(string)
- if !ok || color == "" {
- return fmt.Errorf("第%d个API信息缺少颜色字段", i+1)
- }
-
- // 验证URL格式
- if !urlRegex.MatchString(urlStr) {
- return fmt.Errorf("第%d个API信息的URL格式不正确", i+1)
- }
-
- // 验证URL可解析性
- if _, err := url.Parse(urlStr); err != nil {
- return fmt.Errorf("第%d个API信息的URL无法解析:%s", i+1, err.Error())
- }
-
- // 验证字段长度
- if len(urlStr) > 500 {
- return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1)
- }
-
- if len(route) > 100 {
- return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1)
- }
-
- if len(description) > 200 {
- return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1)
- }
-
- // 验证颜色值
- if !validColors[color] {
- return fmt.Errorf("第%d个API信息的颜色值不合法", i+1)
- }
-
- // 检查并过滤危险字符(防止XSS)
- dangerousChars := []string{"