diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go
index ef5d1935..fcfcb0c3 100644
--- a/common/limiter/limiter.go
+++ b/common/limiter/limiter.go
@@ -5,7 +5,7 @@ import (
_ "embed"
"fmt"
"github.com/go-redis/redis/v8"
- "one-api/common"
+ "one-api/logger"
"sync"
)
@@ -27,7 +27,7 @@ func New(ctx context.Context, r *redis.Client) *RedisLimiter {
// 预加载脚本
limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
if err != nil {
- common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
+ logger.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
}
instance = &RedisLimiter{
client: r,
diff --git a/common/logger.go b/common/logger.go
index 0f6dc3c3..478015f0 100644
--- a/common/logger.go
+++ b/common/logger.go
@@ -1,52 +1,12 @@
package common
import (
- "context"
- "encoding/json"
"fmt"
- "github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
- "io"
- "log"
"os"
- "path/filepath"
- "sync"
"time"
)
-const (
- loggerINFO = "INFO"
- loggerWarn = "WARN"
- loggerError = "ERR"
-)
-
-const maxLogCount = 1000000
-
-var logCount int
-var setupLogLock sync.Mutex
-var setupLogWorking bool
-
-func SetupLogger() {
- if *LogDir != "" {
- ok := setupLogLock.TryLock()
- if !ok {
- log.Println("setup log is already working")
- return
- }
- defer func() {
- setupLogLock.Unlock()
- setupLogWorking = false
- }()
- logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
- fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
- if err != nil {
- log.Fatal("failed to open log file")
- }
- gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
- gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
- }
-}
-
func SysLog(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
@@ -57,67 +17,8 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
-func LogInfo(ctx context.Context, msg string) {
- logHelper(ctx, loggerINFO, msg)
-}
-
-func LogWarn(ctx context.Context, msg string) {
- logHelper(ctx, loggerWarn, msg)
-}
-
-func LogError(ctx context.Context, msg string) {
- logHelper(ctx, loggerError, msg)
-}
-
-func logHelper(ctx context.Context, level string, msg string) {
- writer := gin.DefaultErrorWriter
- if level == loggerINFO {
- writer = gin.DefaultWriter
- }
- id := ctx.Value(RequestIdKey)
- if id == nil {
- id = "SYSTEM"
- }
- now := time.Now()
- _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
- logCount++ // we don't need accurate count, so no lock here
- if logCount > maxLogCount && !setupLogWorking {
- logCount = 0
- setupLogWorking = true
- gopool.Go(func() {
- SetupLogger()
- })
- }
-}
-
func FatalLog(v ...any) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
-
-func LogQuota(quota int) string {
- if DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
- } else {
- return fmt.Sprintf("%d 点额度", quota)
- }
-}
-
-func FormatQuota(quota int) string {
- if DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
- } else {
- return fmt.Sprintf("%d", quota)
- }
-}
-
-// LogJson 仅供测试使用 only for test
-func LogJson(ctx context.Context, msg string, obj any) {
- jsonStr, err := json.Marshal(obj)
- if err != nil {
- LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
- return
- }
- LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
-}
diff --git a/constant/context_key.go b/constant/context_key.go
index b82b19e7..569a0373 100644
--- a/constant/context_key.go
+++ b/constant/context_key.go
@@ -3,6 +3,8 @@ package constant
type ContextKey string
const (
+ ContextKeyPromptTokens ContextKey = "prompt_tokens"
+
ContextKeyOriginalModel ContextKey = "original_model"
ContextKeyRequestStartTime ContextKey = "request_start_time"
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 5152e060..bbf0f97a 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/constant"
+ "one-api/logger"
"one-api/model"
"one-api/service"
"one-api/setting"
@@ -485,8 +486,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
func AutomaticallyUpdateChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("updating all channels")
+ logger.SysLog("updating all channels")
_ = updateAllChannelsBalance()
- common.SysLog("channels update done")
+ logger.SysLog("channels update done")
}
}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 026a863b..ec2e6226 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -13,6 +13,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/middleware"
"one-api/model"
"one-api/relay"
@@ -159,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
// 创建一个用于日志的 info 副本,移除 ApiKey
logInfo := *info
logInfo.ApiKey = ""
- common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
+ logger.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
if err != nil {
@@ -279,7 +280,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
Group: info.UsingGroup,
Other: other,
})
- common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
+ logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return testResult{
context: c,
localErr: nil,
@@ -461,13 +462,13 @@ func TestAllChannels(c *gin.Context) {
func AutomaticallyTestChannels(frequency int) {
if frequency <= 0 {
- common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
+ logger.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
return
}
for {
time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("testing all channels")
+ logger.SysLog("testing all channels")
_ = testAllChannels(false)
- common.SysLog("channel test finished")
+ logger.SysLog("channel test finished")
}
}
diff --git a/controller/console_migrate.go b/controller/console_migrate.go
index d25f199b..d21f5e21 100644
--- a/controller/console_migrate.go
+++ b/controller/console_migrate.go
@@ -3,101 +3,101 @@
package controller
import (
- "encoding/json"
- "net/http"
- "one-api/common"
- "one-api/model"
- "github.com/gin-gonic/gin"
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/logger"
+ "one-api/model"
)
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
func MigrateConsoleSetting(c *gin.Context) {
- // 读取全部 option
- opts, err := model.AllOption()
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
- return
- }
- // 建立 map
- valMap := map[string]string{}
- for _, o := range opts {
- valMap[o.Key] = o.Value
- }
+ // 读取全部 option
+ opts, err := model.AllOption()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ // 建立 map
+ valMap := map[string]string{}
+ for _, o := range opts {
+ valMap[o.Key] = o.Value
+ }
- // 处理 APIInfo
- if v := valMap["ApiInfo"]; v != "" {
- var arr []map[string]interface{}
- if err := json.Unmarshal([]byte(v), &arr); err == nil {
- if len(arr) > 50 {
- arr = arr[:50]
- }
- bytes, _ := json.Marshal(arr)
- model.UpdateOption("console_setting.api_info", string(bytes))
- }
- model.UpdateOption("ApiInfo", "")
- }
- // Announcements 直接搬
- if v := valMap["Announcements"]; v != "" {
- model.UpdateOption("console_setting.announcements", v)
- model.UpdateOption("Announcements", "")
- }
- // FAQ 转换
- if v := valMap["FAQ"]; v != "" {
- var arr []map[string]interface{}
- if err := json.Unmarshal([]byte(v), &arr); err == nil {
- out := []map[string]interface{}{}
- for _, item := range arr {
- q, _ := item["question"].(string)
- if q == "" {
- q, _ = item["title"].(string)
- }
- a, _ := item["answer"].(string)
- if a == "" {
- a, _ = item["content"].(string)
- }
- if q != "" && a != "" {
- out = append(out, map[string]interface{}{"question": q, "answer": a})
- }
- }
- if len(out) > 50 {
- out = out[:50]
- }
- bytes, _ := json.Marshal(out)
- model.UpdateOption("console_setting.faq", string(bytes))
- }
- model.UpdateOption("FAQ", "")
- }
- // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
- url := valMap["UptimeKumaUrl"]
- slug := valMap["UptimeKumaSlug"]
- if url != "" && slug != "" {
- // 仅当同时存在 URL 与 Slug 时才进行迁移
- groups := []map[string]interface{}{
- {
- "id": 1,
- "categoryName": "old",
- "url": url,
- "slug": slug,
- "description": "",
- },
- }
- bytes, _ := json.Marshal(groups)
- model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
- }
- // 清空旧键内容
- if url != "" {
- model.UpdateOption("UptimeKumaUrl", "")
- }
- if slug != "" {
- model.UpdateOption("UptimeKumaSlug", "")
- }
+ // 处理 APIInfo
+ if v := valMap["ApiInfo"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ if len(arr) > 50 {
+ arr = arr[:50]
+ }
+ bytes, _ := json.Marshal(arr)
+ model.UpdateOption("console_setting.api_info", string(bytes))
+ }
+ model.UpdateOption("ApiInfo", "")
+ }
+ // Announcements 直接搬
+ if v := valMap["Announcements"]; v != "" {
+ model.UpdateOption("console_setting.announcements", v)
+ model.UpdateOption("Announcements", "")
+ }
+ // FAQ 转换
+ if v := valMap["FAQ"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ out := []map[string]interface{}{}
+ for _, item := range arr {
+ q, _ := item["question"].(string)
+ if q == "" {
+ q, _ = item["title"].(string)
+ }
+ a, _ := item["answer"].(string)
+ if a == "" {
+ a, _ = item["content"].(string)
+ }
+ if q != "" && a != "" {
+ out = append(out, map[string]interface{}{"question": q, "answer": a})
+ }
+ }
+ if len(out) > 50 {
+ out = out[:50]
+ }
+ bytes, _ := json.Marshal(out)
+ model.UpdateOption("console_setting.faq", string(bytes))
+ }
+ model.UpdateOption("FAQ", "")
+ }
+ // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
+ url := valMap["UptimeKumaUrl"]
+ slug := valMap["UptimeKumaSlug"]
+ if url != "" && slug != "" {
+ // 仅当同时存在 URL 与 Slug 时才进行迁移
+ groups := []map[string]interface{}{
+ {
+ "id": 1,
+ "categoryName": "old",
+ "url": url,
+ "slug": slug,
+ "description": "",
+ },
+ }
+ bytes, _ := json.Marshal(groups)
+ model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
+ }
+ // 清空旧键内容
+ if url != "" {
+ model.UpdateOption("UptimeKumaUrl", "")
+ }
+ if slug != "" {
+ model.UpdateOption("UptimeKumaSlug", "")
+ }
- // 删除旧键记录
- oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
- model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
+ // 删除旧键记录
+ oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
+ model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
- // 重新加载 OptionMap
- model.InitOptionMap()
- common.SysLog("console setting migrated")
- c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
-}
\ No newline at end of file
+ // 重新加载 OptionMap
+ model.InitOptionMap()
+ logger.SysLog("console setting migrated")
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
+}
diff --git a/controller/github.go b/controller/github.go
index 881d6dc1..0715a8fe 100644
--- a/controller/github.go
+++ b/controller/github.go
@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/logger"
"one-api/model"
"strconv"
"time"
@@ -47,7 +48,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
}
res, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -63,7 +64,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 30a5a09a..a67d39c2 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -9,6 +9,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
"one-api/service"
"one-api/setting"
@@ -28,7 +29,7 @@ func UpdateMidjourneyTaskBulk() {
continue
}
- common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
+ logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0)
@@ -47,9 +48,9 @@ func UpdateMidjourneyTaskBulk() {
"progress": "100%",
})
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else {
- common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
+ logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
@@ -57,20 +58,20 @@ func UpdateMidjourneyTaskBulk() {
}
for channelId, taskIds := range taskChannelM {
- common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+ logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
- common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
+ logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
}
continue
}
@@ -81,7 +82,7 @@ func UpdateMidjourneyTaskBulk() {
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
@@ -93,22 +94,22 @@ func UpdateMidjourneyTaskBulk() {
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := service.GetHttpClient().Do(req)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
if resp.StatusCode != http.StatusOK {
- common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+ logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue
}
resp.Body.Close()
@@ -147,12 +148,12 @@ func UpdateMidjourneyTaskBulk() {
}
// 映射 VideoUrl
task.VideoUrl = responseItem.VideoUrl
-
+
// 映射 VideoUrls - 将数组序列化为 JSON 字符串
if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
task.VideoUrls = "[]" // 失败时设置为空数组
} else {
task.VideoUrls = string(videoUrlsStr)
@@ -160,10 +161,10 @@ func UpdateMidjourneyTaskBulk() {
} else {
task.VideoUrls = "" // 空值时清空字段
}
-
+
shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
- common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
+ logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
if task.Quota != 0 {
shouldReturnQuota = true
@@ -171,14 +172,14 @@ func UpdateMidjourneyTaskBulk() {
}
err = task.Update()
if err != nil {
- common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+ logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil {
- common.LogError(ctx, "fail to increase user quota: "+err.Error())
+ logger.LogError(ctx, "fail to increase user quota: "+err.Error())
}
- logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
+ logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
diff --git a/controller/oidc.go b/controller/oidc.go
index df8ea1c4..1e3435a8 100644
--- a/controller/oidc.go
+++ b/controller/oidc.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"one-api/common"
+ "one-api/logger"
"one-api/model"
"one-api/setting"
"one-api/setting/system_setting"
@@ -58,7 +59,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
}
res, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -69,7 +70,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
}
if oidcResponse.AccessToken == "" {
- common.SysError("OIDC 获取 Token 失败,请检查设置!")
+ logger.SysError("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
}
@@ -80,12 +81,12 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
- common.SysError("OIDC 获取用户信息失败!请检查设置!")
+ logger.SysError("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
}
@@ -95,7 +96,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
- common.SysError("OIDC 获取用户信息为空!请检查设置!")
+ logger.SysError("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
}
return &oidcUser, nil
diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go
index 0453870d..6fba0aac 100644
--- a/controller/ratio_sync.go
+++ b/controller/ratio_sync.go
@@ -1,474 +1,474 @@
package controller
import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "strings"
- "sync"
- "time"
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/logger"
+ "strings"
+ "sync"
+ "time"
- "one-api/common"
- "one-api/dto"
- "one-api/model"
- "one-api/setting/ratio_setting"
+ "one-api/dto"
+ "one-api/model"
+ "one-api/setting/ratio_setting"
- "github.com/gin-gonic/gin"
+ "github.com/gin-gonic/gin"
)
const (
- defaultTimeoutSeconds = 10
- defaultEndpoint = "/api/ratio_config"
- maxConcurrentFetches = 8
+ defaultTimeoutSeconds = 10
+ defaultEndpoint = "/api/ratio_config"
+ maxConcurrentFetches = 8
)
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
type upstreamResult struct {
- Name string `json:"name"`
- Data map[string]any `json:"data,omitempty"`
- Err string `json:"err,omitempty"`
+ Name string `json:"name"`
+ Data map[string]any `json:"data,omitempty"`
+ Err string `json:"err,omitempty"`
}
func FetchUpstreamRatios(c *gin.Context) {
- var req dto.UpstreamRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
- return
- }
+ var req dto.UpstreamRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
+ return
+ }
- if req.Timeout <= 0 {
- req.Timeout = defaultTimeoutSeconds
- }
+ if req.Timeout <= 0 {
+ req.Timeout = defaultTimeoutSeconds
+ }
- var upstreams []dto.UpstreamDTO
+ var upstreams []dto.UpstreamDTO
- if len(req.Upstreams) > 0 {
- for _, u := range req.Upstreams {
- if strings.HasPrefix(u.BaseURL, "http") {
- if u.Endpoint == "" {
- u.Endpoint = defaultEndpoint
- }
- u.BaseURL = strings.TrimRight(u.BaseURL, "/")
- upstreams = append(upstreams, u)
- }
- }
- } else if len(req.ChannelIDs) > 0 {
- intIds := make([]int, 0, len(req.ChannelIDs))
- for _, id64 := range req.ChannelIDs {
- intIds = append(intIds, int(id64))
- }
- dbChannels, err := model.GetChannelsByIds(intIds)
- if err != nil {
- common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
- c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
- return
- }
- for _, ch := range dbChannels {
- if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
- upstreams = append(upstreams, dto.UpstreamDTO{
- ID: ch.Id,
- Name: ch.Name,
- BaseURL: strings.TrimRight(base, "/"),
- Endpoint: "",
- })
- }
- }
- }
+ if len(req.Upstreams) > 0 {
+ for _, u := range req.Upstreams {
+ if strings.HasPrefix(u.BaseURL, "http") {
+ if u.Endpoint == "" {
+ u.Endpoint = defaultEndpoint
+ }
+ u.BaseURL = strings.TrimRight(u.BaseURL, "/")
+ upstreams = append(upstreams, u)
+ }
+ }
+ } else if len(req.ChannelIDs) > 0 {
+ intIds := make([]int, 0, len(req.ChannelIDs))
+ for _, id64 := range req.ChannelIDs {
+ intIds = append(intIds, int(id64))
+ }
+ dbChannels, err := model.GetChannelsByIds(intIds)
+ if err != nil {
+ logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
+ return
+ }
+ for _, ch := range dbChannels {
+ if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
+ upstreams = append(upstreams, dto.UpstreamDTO{
+ ID: ch.Id,
+ Name: ch.Name,
+ BaseURL: strings.TrimRight(base, "/"),
+ Endpoint: "",
+ })
+ }
+ }
+ }
- if len(upstreams) == 0 {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
- return
- }
+ if len(upstreams) == 0 {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
+ return
+ }
- var wg sync.WaitGroup
- ch := make(chan upstreamResult, len(upstreams))
+ var wg sync.WaitGroup
+ ch := make(chan upstreamResult, len(upstreams))
- sem := make(chan struct{}, maxConcurrentFetches)
+ sem := make(chan struct{}, maxConcurrentFetches)
- client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
+ client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
- for _, chn := range upstreams {
- wg.Add(1)
- go func(chItem dto.UpstreamDTO) {
- defer wg.Done()
+ for _, chn := range upstreams {
+ wg.Add(1)
+ go func(chItem dto.UpstreamDTO) {
+ defer wg.Done()
- sem <- struct{}{}
- defer func() { <-sem }()
+ sem <- struct{}{}
+ defer func() { <-sem }()
- endpoint := chItem.Endpoint
- if endpoint == "" {
- endpoint = defaultEndpoint
- } else if !strings.HasPrefix(endpoint, "/") {
- endpoint = "/" + endpoint
- }
- fullURL := chItem.BaseURL + endpoint
+ endpoint := chItem.Endpoint
+ if endpoint == "" {
+ endpoint = defaultEndpoint
+ } else if !strings.HasPrefix(endpoint, "/") {
+ endpoint = "/" + endpoint
+ }
+ fullURL := chItem.BaseURL + endpoint
- uniqueName := chItem.Name
- if chItem.ID != 0 {
- uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
- }
+ uniqueName := chItem.Name
+ if chItem.ID != 0 {
+ uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
+ }
- ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
- defer cancel()
+ ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
+ defer cancel()
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
- if err != nil {
- common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
- ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
- return
- }
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
+ if err != nil {
+ logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+ return
+ }
- resp, err := client.Do(httpReq)
- if err != nil {
- common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
- ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
- return
- }
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
- ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
- return
- }
- // 兼容两种上游接口格式:
- // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
- // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
- var body struct {
- Success bool `json:"success"`
- Data json.RawMessage `json:"data"`
- Message string `json:"message"`
- }
+ resp, err := client.Do(httpReq)
+ if err != nil {
+ logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+ return
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
+ ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
+ return
+ }
+ // 兼容两种上游接口格式:
+ // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
+ // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
+ var body struct {
+ Success bool `json:"success"`
+ Data json.RawMessage `json:"data"`
+ Message string `json:"message"`
+ }
- if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
- common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
- ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
- return
- }
+ if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+ logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+ return
+ }
- if !body.Success {
- ch <- upstreamResult{Name: uniqueName, Err: body.Message}
- return
- }
+ if !body.Success {
+ ch <- upstreamResult{Name: uniqueName, Err: body.Message}
+ return
+ }
- // 尝试按 type1 解析
- var type1Data map[string]any
- if err := json.Unmarshal(body.Data, &type1Data); err == nil {
- // 如果包含至少一个 ratioTypes 字段,则认为是 type1
- isType1 := false
- for _, rt := range ratioTypes {
- if _, ok := type1Data[rt]; ok {
- isType1 = true
- break
- }
- }
- if isType1 {
- ch <- upstreamResult{Name: uniqueName, Data: type1Data}
- return
- }
- }
+ // 尝试按 type1 解析
+ var type1Data map[string]any
+ if err := json.Unmarshal(body.Data, &type1Data); err == nil {
+ // 如果包含至少一个 ratioTypes 字段,则认为是 type1
+ isType1 := false
+ for _, rt := range ratioTypes {
+ if _, ok := type1Data[rt]; ok {
+ isType1 = true
+ break
+ }
+ }
+ if isType1 {
+ ch <- upstreamResult{Name: uniqueName, Data: type1Data}
+ return
+ }
+ }
- // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
- var pricingItems []struct {
- ModelName string `json:"model_name"`
- QuotaType int `json:"quota_type"`
- ModelRatio float64 `json:"model_ratio"`
- ModelPrice float64 `json:"model_price"`
- CompletionRatio float64 `json:"completion_ratio"`
- }
- if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
- common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
- ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
- return
- }
+ // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
+ var pricingItems []struct {
+ ModelName string `json:"model_name"`
+ QuotaType int `json:"quota_type"`
+ ModelRatio float64 `json:"model_ratio"`
+ ModelPrice float64 `json:"model_price"`
+ CompletionRatio float64 `json:"completion_ratio"`
+ }
+ if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
+ logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
+ return
+ }
- modelRatioMap := make(map[string]float64)
- completionRatioMap := make(map[string]float64)
- modelPriceMap := make(map[string]float64)
+ modelRatioMap := make(map[string]float64)
+ completionRatioMap := make(map[string]float64)
+ modelPriceMap := make(map[string]float64)
- for _, item := range pricingItems {
- if item.QuotaType == 1 {
- modelPriceMap[item.ModelName] = item.ModelPrice
- } else {
- modelRatioMap[item.ModelName] = item.ModelRatio
- // completionRatio 可能为 0,此时也直接赋值,保持与上游一致
- completionRatioMap[item.ModelName] = item.CompletionRatio
- }
- }
+ for _, item := range pricingItems {
+ if item.QuotaType == 1 {
+ modelPriceMap[item.ModelName] = item.ModelPrice
+ } else {
+ modelRatioMap[item.ModelName] = item.ModelRatio
+ // completionRatio 可能为 0,此时也直接赋值,保持与上游一致
+ completionRatioMap[item.ModelName] = item.CompletionRatio
+ }
+ }
- converted := make(map[string]any)
+ converted := make(map[string]any)
- if len(modelRatioMap) > 0 {
- ratioAny := make(map[string]any, len(modelRatioMap))
- for k, v := range modelRatioMap {
- ratioAny[k] = v
- }
- converted["model_ratio"] = ratioAny
- }
+ if len(modelRatioMap) > 0 {
+ ratioAny := make(map[string]any, len(modelRatioMap))
+ for k, v := range modelRatioMap {
+ ratioAny[k] = v
+ }
+ converted["model_ratio"] = ratioAny
+ }
- if len(completionRatioMap) > 0 {
- compAny := make(map[string]any, len(completionRatioMap))
- for k, v := range completionRatioMap {
- compAny[k] = v
- }
- converted["completion_ratio"] = compAny
- }
+ if len(completionRatioMap) > 0 {
+ compAny := make(map[string]any, len(completionRatioMap))
+ for k, v := range completionRatioMap {
+ compAny[k] = v
+ }
+ converted["completion_ratio"] = compAny
+ }
- if len(modelPriceMap) > 0 {
- priceAny := make(map[string]any, len(modelPriceMap))
- for k, v := range modelPriceMap {
- priceAny[k] = v
- }
- converted["model_price"] = priceAny
- }
+ if len(modelPriceMap) > 0 {
+ priceAny := make(map[string]any, len(modelPriceMap))
+ for k, v := range modelPriceMap {
+ priceAny[k] = v
+ }
+ converted["model_price"] = priceAny
+ }
- ch <- upstreamResult{Name: uniqueName, Data: converted}
- }(chn)
- }
+ ch <- upstreamResult{Name: uniqueName, Data: converted}
+ }(chn)
+ }
- wg.Wait()
- close(ch)
+ wg.Wait()
+ close(ch)
- localData := ratio_setting.GetExposedData()
+ localData := ratio_setting.GetExposedData()
- var testResults []dto.TestResult
- var successfulChannels []struct {
- name string
- data map[string]any
- }
+ var testResults []dto.TestResult
+ var successfulChannels []struct {
+ name string
+ data map[string]any
+ }
- for r := range ch {
- if r.Err != "" {
- testResults = append(testResults, dto.TestResult{
- Name: r.Name,
- Status: "error",
- Error: r.Err,
- })
- } else {
- testResults = append(testResults, dto.TestResult{
- Name: r.Name,
- Status: "success",
- })
- successfulChannels = append(successfulChannels, struct {
- name string
- data map[string]any
- }{name: r.Name, data: r.Data})
- }
- }
+ for r := range ch {
+ if r.Err != "" {
+ testResults = append(testResults, dto.TestResult{
+ Name: r.Name,
+ Status: "error",
+ Error: r.Err,
+ })
+ } else {
+ testResults = append(testResults, dto.TestResult{
+ Name: r.Name,
+ Status: "success",
+ })
+ successfulChannels = append(successfulChannels, struct {
+ name string
+ data map[string]any
+ }{name: r.Name, data: r.Data})
+ }
+ }
- differences := buildDifferences(localData, successfulChannels)
+ differences := buildDifferences(localData, successfulChannels)
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "data": gin.H{
- "differences": differences,
- "test_results": testResults,
- },
- })
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "data": gin.H{
+ "differences": differences,
+ "test_results": testResults,
+ },
+ })
}
func buildDifferences(localData map[string]any, successfulChannels []struct {
- name string
- data map[string]any
+ name string
+ data map[string]any
}) map[string]map[string]dto.DifferenceItem {
- differences := make(map[string]map[string]dto.DifferenceItem)
+ differences := make(map[string]map[string]dto.DifferenceItem)
- allModels := make(map[string]struct{})
-
- for _, ratioType := range ratioTypes {
- if localRatioAny, ok := localData[ratioType]; ok {
- if localRatio, ok := localRatioAny.(map[string]float64); ok {
- for modelName := range localRatio {
- allModels[modelName] = struct{}{}
- }
- }
- }
- }
-
- for _, channel := range successfulChannels {
- for _, ratioType := range ratioTypes {
- if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
- for modelName := range upstreamRatio {
- allModels[modelName] = struct{}{}
- }
- }
- }
- }
+ allModels := make(map[string]struct{})
- confidenceMap := make(map[string]map[string]bool)
-
- // 预处理阶段:检查pricing接口的可信度
- for _, channel := range successfulChannels {
- confidenceMap[channel.name] = make(map[string]bool)
-
- modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
- completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
-
- if hasModelRatio && hasCompletionRatio {
- // 遍历所有模型,检查是否满足不可信条件
- for modelName := range allModels {
- // 默认为可信
- confidenceMap[channel.name][modelName] = true
-
- // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
- if modelRatioVal, ok := modelRatios[modelName]; ok {
- if completionRatioVal, ok := completionRatios[modelName]; ok {
- // 转换为float64进行比较
- if modelRatioFloat, ok := modelRatioVal.(float64); ok {
- if completionRatioFloat, ok := completionRatioVal.(float64); ok {
- if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
- confidenceMap[channel.name][modelName] = false
- }
- }
- }
- }
- }
- }
- } else {
- // 如果不是从pricing接口获取的数据,则全部标记为可信
- for modelName := range allModels {
- confidenceMap[channel.name][modelName] = true
- }
- }
- }
+ for _, ratioType := range ratioTypes {
+ if localRatioAny, ok := localData[ratioType]; ok {
+ if localRatio, ok := localRatioAny.(map[string]float64); ok {
+ for modelName := range localRatio {
+ allModels[modelName] = struct{}{}
+ }
+ }
+ }
+ }
- for modelName := range allModels {
- for _, ratioType := range ratioTypes {
- var localValue interface{} = nil
- if localRatioAny, ok := localData[ratioType]; ok {
- if localRatio, ok := localRatioAny.(map[string]float64); ok {
- if val, exists := localRatio[modelName]; exists {
- localValue = val
- }
- }
- }
+ for _, channel := range successfulChannels {
+ for _, ratioType := range ratioTypes {
+ if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+ for modelName := range upstreamRatio {
+ allModels[modelName] = struct{}{}
+ }
+ }
+ }
+ }
- upstreamValues := make(map[string]interface{})
- confidenceValues := make(map[string]bool)
- hasUpstreamValue := false
- hasDifference := false
+ confidenceMap := make(map[string]map[string]bool)
- for _, channel := range successfulChannels {
- var upstreamValue interface{} = nil
-
- if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
- if val, exists := upstreamRatio[modelName]; exists {
- upstreamValue = val
- hasUpstreamValue = true
-
- if localValue != nil && localValue != val {
- hasDifference = true
- } else if localValue == val {
- upstreamValue = "same"
- }
- }
- }
- if upstreamValue == nil && localValue == nil {
- upstreamValue = "same"
- }
-
- if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
- hasDifference = true
- }
-
- upstreamValues[channel.name] = upstreamValue
-
- confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
- }
+ // 预处理阶段:检查pricing接口的可信度
+ for _, channel := range successfulChannels {
+ confidenceMap[channel.name] = make(map[string]bool)
- shouldInclude := false
-
- if localValue != nil {
- if hasDifference {
- shouldInclude = true
- }
- } else {
- if hasUpstreamValue {
- shouldInclude = true
- }
- }
+ modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
+ completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
- if shouldInclude {
- if differences[modelName] == nil {
- differences[modelName] = make(map[string]dto.DifferenceItem)
- }
- differences[modelName][ratioType] = dto.DifferenceItem{
- Current: localValue,
- Upstreams: upstreamValues,
- Confidence: confidenceValues,
- }
- }
- }
- }
+ if hasModelRatio && hasCompletionRatio {
+ // 遍历所有模型,检查是否满足不可信条件
+ for modelName := range allModels {
+ // 默认为可信
+ confidenceMap[channel.name][modelName] = true
- channelHasDiff := make(map[string]bool)
- for _, ratioMap := range differences {
- for _, item := range ratioMap {
- for chName, val := range item.Upstreams {
- if val != nil && val != "same" {
- channelHasDiff[chName] = true
- }
- }
- }
- }
+ // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
+ if modelRatioVal, ok := modelRatios[modelName]; ok {
+ if completionRatioVal, ok := completionRatios[modelName]; ok {
+ // 转换为float64进行比较
+ if modelRatioFloat, ok := modelRatioVal.(float64); ok {
+ if completionRatioFloat, ok := completionRatioVal.(float64); ok {
+ if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
+ confidenceMap[channel.name][modelName] = false
+ }
+ }
+ }
+ }
+ }
+ }
+ } else {
+ // 如果不是从pricing接口获取的数据,则全部标记为可信
+ for modelName := range allModels {
+ confidenceMap[channel.name][modelName] = true
+ }
+ }
+ }
- for modelName, ratioMap := range differences {
- for ratioType, item := range ratioMap {
- for chName := range item.Upstreams {
- if !channelHasDiff[chName] {
- delete(item.Upstreams, chName)
- delete(item.Confidence, chName)
- }
- }
+ for modelName := range allModels {
+ for _, ratioType := range ratioTypes {
+ var localValue interface{} = nil
+ if localRatioAny, ok := localData[ratioType]; ok {
+ if localRatio, ok := localRatioAny.(map[string]float64); ok {
+ if val, exists := localRatio[modelName]; exists {
+ localValue = val
+ }
+ }
+ }
- allSame := true
- for _, v := range item.Upstreams {
- if v != "same" {
- allSame = false
- break
- }
- }
- if len(item.Upstreams) == 0 || allSame {
- delete(ratioMap, ratioType)
- } else {
- differences[modelName][ratioType] = item
- }
- }
+ upstreamValues := make(map[string]interface{})
+ confidenceValues := make(map[string]bool)
+ hasUpstreamValue := false
+ hasDifference := false
- if len(ratioMap) == 0 {
- delete(differences, modelName)
- }
- }
+ for _, channel := range successfulChannels {
+ var upstreamValue interface{} = nil
- return differences
+ if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+ if val, exists := upstreamRatio[modelName]; exists {
+ upstreamValue = val
+ hasUpstreamValue = true
+
+ if localValue != nil && localValue != val {
+ hasDifference = true
+ } else if localValue == val {
+ upstreamValue = "same"
+ }
+ }
+ }
+ if upstreamValue == nil && localValue == nil {
+ upstreamValue = "same"
+ }
+
+ if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
+ hasDifference = true
+ }
+
+ upstreamValues[channel.name] = upstreamValue
+
+ confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
+ }
+
+ shouldInclude := false
+
+ if localValue != nil {
+ if hasDifference {
+ shouldInclude = true
+ }
+ } else {
+ if hasUpstreamValue {
+ shouldInclude = true
+ }
+ }
+
+ if shouldInclude {
+ if differences[modelName] == nil {
+ differences[modelName] = make(map[string]dto.DifferenceItem)
+ }
+ differences[modelName][ratioType] = dto.DifferenceItem{
+ Current: localValue,
+ Upstreams: upstreamValues,
+ Confidence: confidenceValues,
+ }
+ }
+ }
+ }
+
+ channelHasDiff := make(map[string]bool)
+ for _, ratioMap := range differences {
+ for _, item := range ratioMap {
+ for chName, val := range item.Upstreams {
+ if val != nil && val != "same" {
+ channelHasDiff[chName] = true
+ }
+ }
+ }
+ }
+
+ for modelName, ratioMap := range differences {
+ for ratioType, item := range ratioMap {
+ for chName := range item.Upstreams {
+ if !channelHasDiff[chName] {
+ delete(item.Upstreams, chName)
+ delete(item.Confidence, chName)
+ }
+ }
+
+ allSame := true
+ for _, v := range item.Upstreams {
+ if v != "same" {
+ allSame = false
+ break
+ }
+ }
+ if len(item.Upstreams) == 0 || allSame {
+ delete(ratioMap, ratioType)
+ } else {
+ differences[modelName][ratioType] = item
+ }
+ }
+
+ if len(ratioMap) == 0 {
+ delete(differences, modelName)
+ }
+ }
+
+ return differences
}
func GetSyncableChannels(c *gin.Context) {
- channels, err := model.GetAllChannels(0, 0, true, false)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
+ channels, err := model.GetAllChannels(0, 0, true, false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
- var syncableChannels []dto.SyncableChannel
- for _, channel := range channels {
- if channel.GetBaseURL() != "" {
- syncableChannels = append(syncableChannels, dto.SyncableChannel{
- ID: channel.Id,
- Name: channel.Name,
- BaseURL: channel.GetBaseURL(),
- Status: channel.Status,
- })
- }
- }
+ var syncableChannels []dto.SyncableChannel
+ for _, channel := range channels {
+ if channel.GetBaseURL() != "" {
+ syncableChannels = append(syncableChannels, dto.SyncableChannel{
+ ID: channel.Id,
+ Name: channel.Name,
+ BaseURL: channel.GetBaseURL(),
+ Status: channel.Status,
+ })
+ }
+ }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": syncableChannels,
- })
-}
\ No newline at end of file
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": syncableChannels,
+ })
+}
diff --git a/controller/relay.go b/controller/relay.go
index d235f550..583ac036 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -2,21 +2,22 @@ package controller
import (
"bytes"
- "errors"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/constant"
- constant2 "one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/middleware"
"one-api/model"
"one-api/relay"
+ relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
+ "one-api/setting"
"one-api/types"
"strings"
@@ -24,81 +25,196 @@ import (
"github.com/gorilla/websocket"
)
-func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
+func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
var err *types.NewAPIError
- switch relayMode {
+ switch info.RelayMode {
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
- err = relay.ImageHelper(c)
+ err = relay.ImageHelper(c, info)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
fallthrough
case relayconstant.RelayModeAudioTranscription:
- err = relay.AudioHelper(c)
+ err = relay.AudioHelper(c, info)
case relayconstant.RelayModeRerank:
- err = relay.RerankHelper(c, relayMode)
+ err = relay.RerankHelper(c, info)
case relayconstant.RelayModeEmbeddings:
- err = relay.EmbeddingHelper(c)
+ err = relay.EmbeddingHelper(c, info)
case relayconstant.RelayModeResponses:
- err = relay.ResponsesHelper(c)
- case relayconstant.RelayModeGemini:
- if strings.Contains(c.Request.URL.Path, "embed") {
- err = relay.GeminiEmbeddingHandler(c)
- } else {
- err = relay.GeminiHelper(c)
- }
+ err = relay.ResponsesHelper(c, info)
default:
- err = relay.TextHelper(c)
+ err = relay.TextHelper(c, info)
}
-
- if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
- // 保存错误日志到mysql中
- userId := c.GetInt("id")
- tokenName := c.GetString("token_name")
- modelName := c.GetString("original_model")
- tokenId := c.GetInt("token_id")
- userGroup := c.GetString("group")
- channelId := c.GetInt("channel_id")
- other := make(map[string]interface{})
- other["error_type"] = err.GetErrorType()
- other["error_code"] = err.GetErrorCode()
- other["status_code"] = err.StatusCode
- other["channel_id"] = channelId
- other["channel_name"] = c.GetString("channel_name")
- other["channel_type"] = c.GetInt("channel_type")
- adminInfo := make(map[string]interface{})
- adminInfo["use_channel"] = c.GetStringSlice("use_channel")
- isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
- if isMultiKey {
- adminInfo["is_multi_key"] = true
- adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
- }
- other["admin_info"] = adminInfo
- model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
- }
-
return err
}
-func Relay(c *gin.Context) {
- relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
+func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
+ var err *types.NewAPIError
+ if strings.Contains(c.Request.URL.Path, "embed") {
+ err = relay.GeminiEmbeddingHandler(c, info)
+ } else {
+ err = relay.GeminiHelper(c, info)
+ }
+ return err
+}
+
+func Relay(c *gin.Context, relayFormat types.RelayFormat) {
+
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
- var newAPIError *types.NewAPIError
+
+ var (
+ newAPIError *types.NewAPIError
+ ws *websocket.Conn
+ )
+
+ if relayFormat == types.RelayFormatOpenAIRealtime {
+ var err error
+ ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
+ if err != nil {
+ helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
+ return
+ }
+ defer ws.Close()
+ }
+
+ defer func() {
+ if newAPIError != nil {
+ newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
+ switch relayFormat {
+ case types.RelayFormatOpenAIRealtime:
+ helper.WssError(c, ws, newAPIError.ToOpenAIError())
+ case types.RelayFormatClaude:
+ c.JSON(newAPIError.StatusCode, gin.H{
+ "type": "error",
+ "error": newAPIError.ToClaudeError(),
+ })
+ default:
+ c.JSON(newAPIError.StatusCode, gin.H{
+ "error": newAPIError.ToOpenAIError(),
+ })
+ }
+ }
+ }()
+
+ request, err := helper.GetAndValidateRequest(c, relayFormat)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
+ return
+ }
+
+ //includeUsage := true
+ //// 判断用户是否需要返回使用情况
+ //if textRequest.StreamOptions != nil {
+ // includeUsage = textRequest.StreamOptions.IncludeUsage
+ //}
+ //
+ //// 如果不支持StreamOptions,将StreamOptions设置为nil
+ //if !relayInfo.SupportStreamOptions || !textRequest.Stream {
+ // textRequest.StreamOptions = nil
+ //} else {
+ // // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
+ // if constant.ForceStreamOption {
+ // textRequest.StreamOptions = &dto.StreamOptions{
+ // IncludeUsage: true,
+ // }
+ // }
+ //}
+ //
+ //relayInfo.ShouldIncludeUsage = includeUsage
+
+ relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
+ return
+ }
+
+ meta := request.GetTokenCountMeta()
+
+ if setting.ShouldCheckPromptSensitive() {
+ words, err := service.CheckSensitiveText(meta.CombineText)
+ if err != nil {
+ logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
+ newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
+ return
+ }
+ }
+
+ tokens, err := service.CountRequestToken(c, meta, relayInfo)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
+ return
+ }
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
+ return
+ }
+
+ preConsumedQuota, newApiErr := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if newApiErr != nil {
+ return
+ }
+
+ defer func() {
+ if newApiErr != nil {
+ service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
+ }
+ }()
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
- common.LogError(c, err.Error())
+ logger.LogError(c, err.Error())
newAPIError = err
break
}
- newAPIError = relayRequest(c, relayMode, channel)
+ addUsedChannel(c, channel.Id)
+ requestBody, _ := common.GetRequestBody(c)
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+
+ switch relayFormat {
+ case types.RelayFormatOpenAIRealtime:
+ newAPIError = relay.WssHelper(c, ws)
+ case types.RelayFormatClaude:
+ newAPIError = relay.ClaudeHelper(c, relayInfo)
+ case types.RelayFormatGemini:
+ newAPIError = geminiRelayHandler(c, relayInfo)
+ default:
+ newAPIError = relayHandler(c, relayInfo)
+ }
if newAPIError == nil {
- return // 成功处理请求,直接返回
+ return
+ } else {
+ if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) {
+ // 保存错误日志到mysql中
+ userId := c.GetInt("id")
+ tokenName := c.GetString("token_name")
+ modelName := c.GetString("original_model")
+ tokenId := c.GetInt("token_id")
+ userGroup := c.GetString("group")
+ channelId := c.GetInt("channel_id")
+ other := make(map[string]interface{})
+ other["error_type"] = newAPIError.GetErrorType()
+ other["error_code"] = newAPIError.GetErrorCode()
+ other["status_code"] = newAPIError.StatusCode
+ other["channel_id"] = channelId
+ other["channel_name"] = c.GetString("channel_name")
+ other["channel_type"] = c.GetInt("channel_type")
+ adminInfo := make(map[string]interface{})
+ adminInfo["use_channel"] = c.GetStringSlice("use_channel")
+ isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
+ if isMultiKey {
+ adminInfo["is_multi_key"] = true
+ adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
+ }
+ other["admin_info"] = adminInfo
+ model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
+ }
}
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
@@ -107,21 +223,11 @@ func Relay(c *gin.Context) {
break
}
}
+
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- common.LogInfo(c, retryLogStr)
- }
-
- if newAPIError != nil {
- //if newAPIError.StatusCode == http.StatusTooManyRequests {
- // common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
- // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
- //}
- newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
- c.JSON(newAPIError.StatusCode, gin.H{
- "error": newAPIError.ToOpenAIError(),
- })
+ logger.LogInfo(c, retryLogStr)
}
}
@@ -132,122 +238,6 @@ var upgrader = websocket.Upgrader{
},
}
-func WssRelay(c *gin.Context) {
- // 将 HTTP 连接升级为 WebSocket 连接
-
- ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
- defer ws.Close()
-
- if err != nil {
- helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
- return
- }
-
- relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
- requestId := c.GetString(common.RequestIdKey)
- group := c.GetString("group")
- //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
- originalModel := c.GetString("original_model")
- var newAPIError *types.NewAPIError
-
- for i := 0; i <= common.RetryTimes; i++ {
- channel, err := getChannel(c, group, originalModel, i)
- if err != nil {
- common.LogError(c, err.Error())
- newAPIError = err
- break
- }
-
- newAPIError = wssRequest(c, ws, relayMode, channel)
-
- if newAPIError == nil {
- return // 成功处理请求,直接返回
- }
-
- go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
-
- if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
- break
- }
- }
- useChannel := c.GetStringSlice("use_channel")
- if len(useChannel) > 1 {
- retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- common.LogInfo(c, retryLogStr)
- }
-
- if newAPIError != nil {
- //if newAPIError.StatusCode == http.StatusTooManyRequests {
- // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
- //}
- newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
- helper.WssError(c, ws, newAPIError.ToOpenAIError())
- }
-}
-
-func RelayClaude(c *gin.Context) {
- //relayMode := constant.Path2RelayMode(c.Request.URL.Path)
- requestId := c.GetString(common.RequestIdKey)
- group := c.GetString("group")
- originalModel := c.GetString("original_model")
- var newAPIError *types.NewAPIError
-
- for i := 0; i <= common.RetryTimes; i++ {
- channel, err := getChannel(c, group, originalModel, i)
- if err != nil {
- common.LogError(c, err.Error())
- newAPIError = err
- break
- }
-
- newAPIError = claudeRequest(c, channel)
-
- if newAPIError == nil {
- return // 成功处理请求,直接返回
- }
-
- go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
-
- if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
- break
- }
- }
- useChannel := c.GetStringSlice("use_channel")
- if len(useChannel) > 1 {
- retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- common.LogInfo(c, retryLogStr)
- }
-
- if newAPIError != nil {
- newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
- c.JSON(newAPIError.StatusCode, gin.H{
- "type": "error",
- "error": newAPIError.ToClaudeError(),
- })
- }
-}
-
-func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
- addUsedChannel(c, channel.Id)
- requestBody, _ := common.GetRequestBody(c)
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- return relayHandler(c, relayMode)
-}
-
-func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
- addUsedChannel(c, channel.Id)
- requestBody, _ := common.GetRequestBody(c)
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- return relay.WssHelper(c, ws)
-}
-
-func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
- addUsedChannel(c, channel.Id)
- requestBody, _ := common.GetRequestBody(c)
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- return relay.ClaudeHelper(c)
-}
-
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
@@ -270,10 +260,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
}
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
- return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
+ return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
if channel == nil {
- return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
+ return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
if newAPIError != nil {
@@ -327,7 +317,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
- common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
+ logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
service.DisableChannel(channelError, err.Error())
}
@@ -362,7 +352,7 @@ func RelayMidjourney(c *gin.Context) {
"code": err.Code,
})
channelId := c.GetInt("channel_id")
- common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
+ logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
}
}
@@ -404,7 +394,7 @@ func RelayTask(c *gin.Context) {
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, newAPIError := getChannel(c, group, originalModel, i)
if newAPIError != nil {
- common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
+ logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
break
}
@@ -412,7 +402,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
- common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+ logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, _ := common.GetRequestBody(c)
@@ -422,7 +412,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- common.LogInfo(c, retryLogStr)
+ logger.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {
diff --git a/controller/task.go b/controller/task.go
index 5fbdb424..a5b28ae2 100644
--- a/controller/task.go
+++ b/controller/task.go
@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
"one-api/relay"
"sort"
@@ -25,7 +26,7 @@ func UpdateTaskBulk() {
//imageModel := "midjourney"
for {
time.Sleep(time.Duration(15) * time.Second)
- common.SysLog("任务进度轮询开始")
+ logger.SysLog("任务进度轮询开始")
ctx := context.TODO()
allTasks := model.GetAllUnFinishSyncTasks(500)
platformTask := make(map[constant.TaskPlatform][]*model.Task)
@@ -54,9 +55,9 @@ func UpdateTaskBulk() {
"progress": "100%",
})
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
} else {
- common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+ logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
@@ -65,7 +66,7 @@ func UpdateTaskBulk() {
UpdateTaskByPlatform(platform, taskChannelM, taskM)
}
- common.SysLog("任务进度轮询完成")
+ logger.SysLog("任务进度轮询完成")
}
}
@@ -77,7 +78,7 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
default:
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
- common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
+ logger.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
}
}
}
@@ -86,27 +87,27 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
for channelId, taskIds := range taskChannelM {
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
+ logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
}
}
return nil
}
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+ logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
channel, err := model.CacheGetChannel(channelId)
if err != nil {
- common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+ logger.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
err = model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
- common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
+ logger.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
}
return err
}
@@ -118,27 +119,27 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
"ids": taskIds,
})
if err != nil {
- common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
+ logger.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
return err
}
if resp.StatusCode != http.StatusOK {
- common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
+ logger.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
return err
}
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+ logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
return err
}
if !responseItems.IsSuccess() {
- common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
+ logger.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
return err
}
@@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
- common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+ logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
task.Progress = "100%"
//err = model.CacheUpdateUserQuota(task.UserId) ?
if err != nil {
- common.LogError(ctx, "error update user quota cache: "+err.Error())
+ logger.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil {
- common.LogError(ctx, "fail to increase user quota: "+err.Error())
+ logger.LogError(ctx, "fail to increase user quota: "+err.Error())
}
- logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
+ logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
@@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
err = task.Update()
if err != nil {
- common.SysError("UpdateMidjourneyTask task error: " + err.Error())
+ logger.SysError("UpdateMidjourneyTask task error: " + err.Error())
}
}
return nil
diff --git a/controller/task_video.go b/controller/task_video.go
index 914bf6e6..dca42955 100644
--- a/controller/task_video.go
+++ b/controller/task_video.go
@@ -5,9 +5,9 @@ import (
"encoding/json"
"fmt"
"io"
- "one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
"one-api/relay"
"one-api/relay/channel"
@@ -18,14 +18,14 @@ import (
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
- common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+ logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+ logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
@@ -37,7 +37,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
"progress": "100%",
})
if errUpdate != nil {
- common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+ logger.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
@@ -47,7 +47,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
}
for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
- common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+ logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
}
}
return nil
@@ -61,7 +61,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task := taskM[taskId]
if task == nil {
- common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+ logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
@@ -124,13 +124,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task.FinishTime = now
}
task.FailReason = taskResult.Reason
- common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+ logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
if quota != 0 {
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
- common.LogError(ctx, "Failed to increase user quota: "+err.Error())
+ logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
}
- logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
+ logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
default:
@@ -140,7 +140,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task.Progress = taskResult.Progress
}
if err := task.Update(); err != nil {
- common.SysError("UpdateVideoTask task error: " + err.Error())
+ logger.SysError("UpdateVideoTask task error: " + err.Error())
}
return nil
diff --git a/controller/token.go b/controller/token.go
index 62eb5474..db575fec 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -3,6 +3,7 @@ package controller
import (
"net/http"
"one-api/common"
+ "one-api/logger"
"one-api/model"
"strconv"
@@ -102,7 +103,7 @@ func AddToken(c *gin.Context) {
"success": false,
"message": "生成令牌失败",
})
- common.SysError("failed to generate token key: " + err.Error())
+ logger.SysError("failed to generate token key: " + err.Error())
return
}
cleanToken := model.Token{
diff --git a/controller/topup.go b/controller/topup.go
index 827dda39..3f3c8623 100644
--- a/controller/topup.go
+++ b/controller/topup.go
@@ -5,6 +5,7 @@ import (
"log"
"net/url"
"one-api/common"
+ "one-api/logger"
"one-api/model"
"one-api/service"
"one-api/setting"
@@ -231,7 +232,7 @@ func EpayNotify(c *gin.Context) {
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
- model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
+ model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
diff --git a/controller/twofa.go b/controller/twofa.go
index 9f48eed8..0ab66029 100644
--- a/controller/twofa.go
+++ b/controller/twofa.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/logger"
"one-api/model"
"strconv"
@@ -70,7 +71,7 @@ func Setup2FA(c *gin.Context) {
"success": false,
"message": "生成2FA密钥失败",
})
- common.SysError("生成TOTP密钥失败: " + err.Error())
+ logger.SysError("生成TOTP密钥失败: " + err.Error())
return
}
@@ -81,7 +82,7 @@ func Setup2FA(c *gin.Context) {
"success": false,
"message": "生成备用码失败",
})
- common.SysError("生成备用码失败: " + err.Error())
+ logger.SysError("生成备用码失败: " + err.Error())
return
}
@@ -115,7 +116,7 @@ func Setup2FA(c *gin.Context) {
"success": false,
"message": "保存备用码失败",
})
- common.SysError("保存备用码失败: " + err.Error())
+ logger.SysError("保存备用码失败: " + err.Error())
return
}
@@ -294,7 +295,7 @@ func Get2FAStatus(c *gin.Context) {
// 获取剩余备用码数量
backupCount, err := model.GetUnusedBackupCodeCount(userId)
if err != nil {
- common.SysError("获取备用码数量失败: " + err.Error())
+ logger.SysError("获取备用码数量失败: " + err.Error())
} else {
status["backup_codes_remaining"] = backupCount
}
@@ -368,7 +369,7 @@ func RegenerateBackupCodes(c *gin.Context) {
"success": false,
"message": "生成备用码失败",
})
- common.SysError("生成备用码失败: " + err.Error())
+ logger.SysError("生成备用码失败: " + err.Error())
return
}
@@ -378,7 +379,7 @@ func RegenerateBackupCodes(c *gin.Context) {
"success": false,
"message": "保存备用码失败",
})
- common.SysError("保存备用码失败: " + err.Error())
+ logger.SysError("保存备用码失败: " + err.Error())
return
}
diff --git a/controller/user.go b/controller/user.go
index 29cf83e1..8ce44fa6 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -7,6 +7,7 @@ import (
"net/url"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
"one-api/setting"
"strconv"
@@ -192,7 +193,7 @@ func Register(c *gin.Context) {
"success": false,
"message": "数据库错误,请稍后重试",
})
- common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
+ logger.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return
}
if exist {
@@ -235,7 +236,7 @@ func Register(c *gin.Context) {
"success": false,
"message": "生成默认令牌失败",
})
- common.SysError("failed to generate token key: " + err.Error())
+ logger.SysError("failed to generate token key: " + err.Error())
return
}
// 生成默认令牌
@@ -342,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) {
"success": false,
"message": "生成失败",
})
- common.SysError("failed to generate key: " + err.Error())
+ logger.SysError("failed to generate key: " + err.Error())
return
}
user.SetAccessToken(key)
@@ -517,7 +518,7 @@ func UpdateUser(c *gin.Context) {
return
}
if originUser.Quota != updatedUser.Quota {
- model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
+ model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
}
c.JSON(http.StatusOK, gin.H{
"success": true,
diff --git a/dto/audio.go b/dto/audio.go
index c36b3da5..81872c69 100644
--- a/dto/audio.go
+++ b/dto/audio.go
@@ -1,5 +1,11 @@
package dto
+import (
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
type AudioRequest struct {
Model string `json:"model"`
Input string `json:"input"`
@@ -8,6 +14,18 @@ type AudioRequest struct {
ResponseFormat string `json:"response_format,omitempty"`
}
+func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ meta := &types.TokenCountMeta{
+ CombineText: r.Input,
+ TokenType: types.TokenTypeTextNumber,
+ }
+ return meta
+}
+
+func (r *AudioRequest) IsStream(c *gin.Context) bool {
+ return false
+}
+
type AudioResponse struct {
Text string `json:"text"`
}
diff --git a/dto/claude.go b/dto/claude.go
index 58a09217..2b3adf19 100644
--- a/dto/claude.go
+++ b/dto/claude.go
@@ -5,6 +5,9 @@ import (
"fmt"
"one-api/common"
"one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
)
type ClaudeMetadata struct {
@@ -81,7 +84,7 @@ func (c *ClaudeMediaMessage) GetStringContent() string {
}
func (c *ClaudeMediaMessage) GetJsonRowString() string {
- jsonContent, _ := json.Marshal(c)
+ jsonContent, _ := common.Marshal(c)
return string(jsonContent)
}
@@ -199,6 +202,129 @@ type ClaudeRequest struct {
Thinking *Thinking `json:"thinking,omitempty"`
}
+func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var tokenCountMeta = types.TokenCountMeta{
+ TokenType: types.TokenTypeTextNumber,
+ MaxTokens: int(c.MaxTokens),
+ }
+
+ var texts = make([]string, 0)
+ var fileMeta = make([]*types.FileMeta, 0)
+
+ // system
+ if c.System != nil {
+ if c.IsStringSystem() {
+ sys := c.GetStringSystem()
+ if sys != "" {
+ texts = append(texts, sys)
+ }
+ } else {
+ systemMedia := c.ParseSystem()
+ for _, media := range systemMedia {
+ switch media.Type {
+ case "text":
+ texts = append(texts, media.GetText())
+ case "image":
+ if media.Source != nil {
+ data := media.Source.Url
+ if data == "" {
+ data = common.Interface2String(media.Source.Data)
+ }
+ if data != "" {
+ fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // messages
+ for _, message := range c.Messages {
+ tokenCountMeta.MessagesCount++
+ texts = append(texts, message.Role)
+ if message.IsStringContent() {
+ content := message.GetStringContent()
+ if content != "" {
+ texts = append(texts, content)
+ }
+ continue
+ }
+
+ content, _ := message.ParseContent()
+ for _, media := range content {
+ switch media.Type {
+ case "text":
+ texts = append(texts, media.GetText())
+ case "image":
+ if media.Source != nil {
+ data := media.Source.Url
+ if data == "" {
+ data = common.Interface2String(media.Source.Data)
+ }
+ if data != "" {
+ fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
+ }
+ }
+ case "tool_use":
+ if media.Name != "" {
+ texts = append(texts, media.Name)
+ }
+ if media.Input != nil {
+ b, _ := common.Marshal(media.Input)
+ texts = append(texts, string(b))
+ }
+ case "tool_result":
+ if media.Content != nil {
+ b, _ := common.Marshal(media.Content)
+ texts = append(texts, string(b))
+ }
+ }
+ }
+ }
+
+ // tools
+ if c.Tools != nil {
+ tools := c.GetTools()
+ normalTools, webSearchTools := ProcessTools(tools)
+ if normalTools != nil {
+ for _, t := range normalTools {
+ tokenCountMeta.ToolsCount++
+ if t.Name != "" {
+ texts = append(texts, t.Name)
+ }
+ if t.Description != "" {
+ texts = append(texts, t.Description)
+ }
+ if t.InputSchema != nil {
+ b, _ := common.Marshal(t.InputSchema)
+ texts = append(texts, string(b))
+ }
+ }
+ }
+ if webSearchTools != nil {
+ for _, t := range webSearchTools {
+ tokenCountMeta.ToolsCount++
+ if t.Name != "" {
+ texts = append(texts, t.Name)
+ }
+ if t.UserLocation != nil {
+ b, _ := common.Marshal(t.UserLocation)
+ texts = append(texts, string(b))
+ }
+ }
+ }
+ }
+
+ tokenCountMeta.CombineText = strings.Join(texts, "\n")
+ tokenCountMeta.Files = fileMeta
+ return &tokenCountMeta
+}
+
+func (claudeRequest *ClaudeRequest) IsStream(c *gin.Context) bool {
+ return claudeRequest.Stream
+}
+
func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
for _, message := range c.Messages {
content, _ := message.ParseContent()
diff --git a/dto/embedding.go b/dto/embedding.go
index 9d722292..fff37776 100644
--- a/dto/embedding.go
+++ b/dto/embedding.go
@@ -1,5 +1,12 @@
package dto
+import (
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
type EmbeddingOptions struct {
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
@@ -24,9 +31,26 @@ type EmbeddingRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
-func (r EmbeddingRequest) ParseInput() []string {
+func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var texts = make([]string, 0)
+
+ inputs := r.ParseInput()
+ for _, input := range inputs {
+ texts = append(texts, input)
+ }
+
+ return &types.TokenCountMeta{
+ CombineText: strings.Join(texts, "\n"),
+ }
+}
+
+func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
+ return false
+}
+
+func (r *EmbeddingRequest) ParseInput() []string {
if r.Input == nil {
- return nil
+ return make([]string, 0)
}
var input []string
switch r.Input.(type) {
diff --git a/dto/gemini.go b/dto/gemini.go
index 6cb3e17a..b327de62 100644
--- a/dto/gemini.go
+++ b/dto/gemini.go
@@ -2,7 +2,10 @@ package dto
import (
"encoding/json"
+ "github.com/gin-gonic/gin"
"one-api/common"
+ "one-api/logger"
+ "one-api/types"
"strings"
)
@@ -14,19 +17,75 @@ type GeminiChatRequest struct {
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
}
+func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var files []*types.FileMeta = make([]*types.FileMeta, 0)
+
+ var maxTokens int
+
+ if r.GenerationConfig.MaxOutputTokens > 0 {
+ maxTokens = int(r.GenerationConfig.MaxOutputTokens)
+ }
+
+ var inputTexts []string
+ for _, content := range r.Contents {
+ for _, part := range content.Parts {
+ if part.Text != "" {
+ inputTexts = append(inputTexts, part.Text)
+ }
+ if part.InlineData != nil && part.InlineData.Data != "" {
+ if strings.HasPrefix(part.InlineData.MimeType, "image/") {
+ files = append(files, &types.FileMeta{
+ FileType: types.FileTypeImage,
+ Data: part.InlineData.Data,
+ })
+ } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
+ files = append(files, &types.FileMeta{
+ FileType: types.FileTypeAudio,
+ Data: part.InlineData.Data,
+ })
+ } else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
+ files = append(files, &types.FileMeta{
+ FileType: types.FileTypeVideo,
+ Data: part.InlineData.Data,
+ })
+ } else {
+ files = append(files, &types.FileMeta{
+ FileType: types.FileTypeFile,
+ Data: part.InlineData.Data,
+ })
+ }
+ }
+ }
+ }
+
+ inputText := strings.Join(inputTexts, "\n")
+ return &types.TokenCountMeta{
+ CombineText: inputText,
+ Files: files,
+ MaxTokens: maxTokens,
+ }
+}
+
+func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
+ if c.Query("alt") == "sse" {
+ return true
+ }
+ return false
+}
+
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
var tools []GeminiChatTool
if strings.HasSuffix(string(r.Tools), "[") {
// is array
if err := common.Unmarshal(r.Tools, &tools); err != nil {
- common.LogError(nil, "error_unmarshalling_tools: "+err.Error())
+ logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
return nil
}
} else if strings.HasPrefix(string(r.Tools), "{") {
// is object
singleTool := GeminiChatTool{}
if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
- common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
+ logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
return nil
}
tools = []GeminiChatTool{singleTool}
@@ -43,7 +102,7 @@ func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
// Marshal the tools to JSON
data, err := common.Marshal(tools)
if err != nil {
- common.LogError(nil, "error_marshalling_tools: "+err.Error())
+ logger.LogError(nil, "error_marshalling_tools: "+err.Error())
return
}
r.Tools = data
diff --git a/dto/dalle.go b/dto/openai_image.go
similarity index 51%
rename from dto/dalle.go
rename to dto/openai_image.go
index ce2f6361..7431935b 100644
--- a/dto/dalle.go
+++ b/dto/openai_image.go
@@ -1,11 +1,17 @@
package dto
-import "encoding/json"
+import (
+ "encoding/json"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
- N int `json:"n,omitempty"`
+ N uint `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
@@ -18,6 +24,42 @@ type ImageRequest struct {
Watermark *bool `json:"watermark,omitempty"`
}
+func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var sizeRatio = 1.0
+ var qualityRatio = 1.0
+
+ if strings.HasPrefix(i.Model, "dall-e") {
+ // Size
+ if i.Size == "256x256" {
+ sizeRatio = 0.4
+ } else if i.Size == "512x512" {
+ sizeRatio = 0.45
+ } else if i.Size == "1024x1024" {
+ sizeRatio = 1
+ } else if i.Size == "1024x1792" || i.Size == "1792x1024" {
+ sizeRatio = 2
+ }
+
+ if i.Model == "dall-e-3" && i.Quality == "hd" {
+ qualityRatio = 2.0
+ if i.Size == "1024x1792" || i.Size == "1792x1024" {
+ qualityRatio = 1.5
+ }
+ }
+ }
+
+ // not support token count for dalle
+ return &types.TokenCountMeta{
+ CombineText: i.Prompt,
+ MaxTokens: 1584,
+ ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
+ }
+}
+
+func (i *ImageRequest) IsStream(c *gin.Context) bool {
+ return false
+}
+
type ImageResponse struct {
Data []ImageData `json:"data"`
Created int64 `json:"created"`
diff --git a/dto/openai_request.go b/dto/openai_request.go
index 7a23ca5c..0c01c503 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -2,8 +2,12 @@ package dto
import (
"encoding/json"
+ "fmt"
"one-api/common"
+ "one-api/types"
"strings"
+
+ "github.com/gin-gonic/gin"
)
type ResponseFormat struct {
@@ -67,6 +71,116 @@ type GeneralOpenAIRequest struct {
Extra map[string]json.RawMessage `json:"-"`
}
+func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var tokenCountMeta types.TokenCountMeta
+ var texts = make([]string, 0)
+ var fileMeta = make([]*types.FileMeta, 0)
+
+ if r.Prompt != nil {
+ switch v := r.Prompt.(type) {
+ case string:
+ texts = append(texts, v)
+ case []any:
+ for _, item := range v {
+ if str, ok := item.(string); ok {
+ texts = append(texts, str)
+ }
+ }
+ default:
+ texts = append(texts, fmt.Sprintf("%v", r.Prompt))
+ }
+ }
+
+ if r.Input != nil {
+ inputs := r.ParseInput()
+ texts = append(texts, inputs...)
+ }
+
+ if r.MaxCompletionTokens > r.MaxTokens {
+ tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
+ } else {
+ tokenCountMeta.MaxTokens = int(r.MaxTokens)
+ }
+
+ for _, message := range r.Messages {
+ tokenCountMeta.MessagesCount++
+ texts = append(texts, message.Role)
+ if message.Content != nil {
+ if message.Name != nil {
+ tokenCountMeta.NameCount++
+ texts = append(texts, *message.Name)
+ }
+ arrayContent := message.ParseContent()
+ for _, m := range arrayContent {
+ if m.Type == ContentTypeImageURL {
+ imageUrl := m.GetImageMedia()
+ if imageUrl != nil {
+ meta := &types.FileMeta{
+ FileType: types.FileTypeImage,
+ }
+ meta.Data = imageUrl.Url
+ meta.Detail = imageUrl.Detail
+ fileMeta = append(fileMeta, meta)
+ }
+ } else if m.Type == ContentTypeInputAudio {
+ inputAudio := m.GetInputAudio()
+ if inputAudio != nil {
+ meta := &types.FileMeta{
+ FileType: types.FileTypeAudio,
+ }
+ meta.Data = inputAudio.Data
+ fileMeta = append(fileMeta, meta)
+ }
+ } else if m.Type == ContentTypeFile {
+ file := m.GetFile()
+ if file != nil {
+ meta := &types.FileMeta{
+ FileType: types.FileTypeFile,
+ }
+ meta.Data = file.FileData
+ fileMeta = append(fileMeta, meta)
+ }
+ } else if m.Type == ContentTypeVideoUrl {
+ videoUrl := m.GetVideoUrl()
+ if videoUrl != nil {
+ meta := &types.FileMeta{
+ FileType: types.FileTypeVideo,
+ }
+ meta.Data = videoUrl.Url
+ fileMeta = append(fileMeta, meta)
+ }
+ } else {
+ texts = append(texts, m.Text)
+ }
+ }
+ }
+ }
+
+ if r.Tools != nil {
+ openaiTools := r.Tools
+ for _, tool := range openaiTools {
+ tokenCountMeta.ToolsCount++
+ texts = append(texts, tool.Function.Name)
+ if tool.Function.Description != "" {
+ texts = append(texts, tool.Function.Description)
+ }
+ if tool.Function.Parameters != nil {
+ texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters))
+ }
+ }
+ //toolTokens := CountTokenInput(countStr, request.Model)
+ //tkm += 8
+ //tkm += toolTokens
+ }
+ tokenCountMeta.CombineText = strings.Join(texts, "\n")
+ tokenCountMeta.Files = fileMeta
+ return &tokenCountMeta
+}
+
+func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
+ return r.Stream
+}
+
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
result := make(map[string]any)
data, _ := common.Marshal(r)
@@ -202,10 +316,25 @@ func (m *MediaContent) GetFile() *MessageFile {
return nil
}
+func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
+ if m.VideoUrl != nil {
+ if _, ok := m.VideoUrl.(*MessageVideoUrl); ok {
+ return m.VideoUrl.(*MessageVideoUrl)
+ }
+ if itemMap, ok := m.VideoUrl.(map[string]any); ok {
+ out := &MessageVideoUrl{
+ Url: common.Interface2String(itemMap["url"]),
+ }
+ return out
+ }
+ }
+ return nil
+}
+
type MessageImageUrl struct {
- Url string `json:"url"`
- Detail string `json:"detail"`
- MimeType string
+ Url string `json:"url"`
+ Detail string `json:"detail"`
+ //MimeType string
}
func (m *MessageImageUrl) IsRemoteImage() bool {
@@ -233,6 +362,7 @@ const (
ContentTypeInputAudio = "input_audio"
ContentTypeFile = "file"
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
+ //ContentTypeAudioUrl = "audio_url"
)
func (m *Message) GetPrefix() bool {
@@ -623,7 +753,7 @@ type WebSearchOptions struct {
// https://platform.openai.com/docs/api-reference/responses/create
type OpenAIResponsesRequest struct {
Model string `json:"model"`
- Input json.RawMessage `json:"input,omitempty"`
+ Input any `json:"input,omitempty"`
Include json.RawMessage `json:"include,omitempty"`
Instructions json.RawMessage `json:"instructions,omitempty"`
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
@@ -645,28 +775,145 @@ type OpenAIResponsesRequest struct {
Prompt json.RawMessage `json:"prompt,omitempty"`
}
+func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var fileMeta = make([]*types.FileMeta, 0)
+ var texts = make([]string, 0)
+
+ if r.Input != nil {
+ inputs := r.ParseInput()
+ for _, input := range inputs {
+ if input.Type == "input_image" {
+ fileMeta = append(fileMeta, &types.FileMeta{
+ FileType: types.FileTypeImage,
+ Data: input.ImageUrl,
+ Detail: input.Detail,
+ })
+ } else if input.Type == "input_file" {
+ fileMeta = append(fileMeta, &types.FileMeta{
+ FileType: types.FileTypeFile,
+ Data: input.FileUrl,
+ })
+ } else {
+ texts = append(texts, input.Text)
+ }
+ }
+ }
+
+ if len(r.Instructions) > 0 {
+ texts = append(texts, string(r.Instructions))
+ }
+
+ if len(r.Metadata) > 0 {
+ texts = append(texts, string(r.Metadata))
+ }
+
+ if len(r.Text) > 0 {
+ texts = append(texts, string(r.Text))
+ }
+
+ if len(r.ToolChoice) > 0 {
+ texts = append(texts, string(r.ToolChoice))
+ }
+
+ if len(r.Prompt) > 0 {
+ texts = append(texts, string(r.Prompt))
+ }
+
+ if len(r.Tools) > 0 {
+ toolStr, _ := common.Marshal(r.Tools)
+ texts = append(texts, string(toolStr))
+ }
+
+ return &types.TokenCountMeta{
+ CombineText: strings.Join(texts, "\n"),
+ Files: fileMeta,
+ MaxTokens: int(r.MaxOutputTokens),
+ }
+}
+
+func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
+ return r.Stream
+}
+
type Reasoning struct {
Effort string `json:"effort,omitempty"`
Summary string `json:"summary,omitempty"`
}
-//type ResponsesToolsCall struct {
-// Type string `json:"type"`
-// // Web Search
-// UserLocation json.RawMessage `json:"user_location,omitempty"`
-// SearchContextSize string `json:"search_context_size,omitempty"`
-// // File Search
-// VectorStoreIds []string `json:"vector_store_ids,omitempty"`
-// MaxNumResults uint `json:"max_num_results,omitempty"`
-// Filters json.RawMessage `json:"filters,omitempty"`
-// // Computer Use
-// DisplayWidth uint `json:"display_width,omitempty"`
-// DisplayHeight uint `json:"display_height,omitempty"`
-// Environment string `json:"environment,omitempty"`
-// // Function
-// Name string `json:"name,omitempty"`
-// Description string `json:"description,omitempty"`
-// Parameters json.RawMessage `json:"parameters,omitempty"`
-// Function json.RawMessage `json:"function,omitempty"`
-// Container json.RawMessage `json:"container,omitempty"`
-//}
+type MediaInput struct {
+ Type string `json:"type"`
+ Text string `json:"text,omitempty"`
+ FileUrl string `json:"file_url,omitempty"`
+ ImageUrl string `json:"image_url,omitempty"`
+ Detail string `json:"detail,omitempty"` // 仅 input_image 有效
+}
+
+// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput.
+// Reference implementation mirrors Message.ParseContent:
+// - input can be a string, treated as an input_text item
+// - input can be an array of objects with a `type` field
+// supported types: input_text, input_image, input_file
+func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
+ if r.Input == nil {
+ return nil
+ }
+
+ var inputs []MediaInput
+
+ // Try string first
+ if str, ok := r.Input.(string); ok {
+ inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
+ return inputs
+ }
+
+ // Try array of parts
+ if array, ok := r.Input.([]any); ok {
+ for _, itemAny := range array {
+ // Already parsed MediaInput
+ if media, ok := itemAny.(MediaInput); ok {
+ inputs = append(inputs, media)
+ continue
+ }
+ // Generic map
+ item, ok := itemAny.(map[string]any)
+ if !ok {
+ continue
+ }
+ typeVal, ok := item["type"].(string)
+ if !ok {
+ continue
+ }
+ switch typeVal {
+ case "input_text":
+ text, _ := item["text"].(string)
+ inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
+ case "input_image":
+ // image_url may be string or object with url field
+ var imageUrl string
+ switch v := item["image_url"].(type) {
+ case string:
+ imageUrl = v
+ case map[string]any:
+ if url, ok := v["url"].(string); ok {
+ imageUrl = url
+ }
+ }
+ inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
+ case "input_file":
+ // file_url may be string or object with url field
+ var fileUrl string
+ switch v := item["file_url"].(type) {
+ case string:
+ fileUrl = v
+ case map[string]any:
+ if url, ok := v["url"].(string); ok {
+ fileUrl = url
+ }
+ }
+ inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
+ }
+ }
+ }
+
+ return inputs
+}
diff --git a/dto/request_common.go b/dto/request_common.go
new file mode 100644
index 00000000..e5dde8b5
--- /dev/null
+++ b/dto/request_common.go
@@ -0,0 +1,11 @@
+package dto
+
+import (
+ "github.com/gin-gonic/gin"
+ "one-api/types"
+)
+
+type Request interface {
+ GetTokenCountMeta() *types.TokenCountMeta
+ IsStream(c *gin.Context) bool
+}
diff --git a/dto/rerank.go b/dto/rerank.go
index 5ea68cba..ca4da9e1 100644
--- a/dto/rerank.go
+++ b/dto/rerank.go
@@ -1,5 +1,12 @@
package dto
+import (
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "one-api/types"
+ "strings"
+)
+
type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
@@ -10,6 +17,26 @@ type RerankRequest struct {
OverLapTokens int `json:"overlap_tokens,omitempty"`
}
+func (r *RerankRequest) IsStream(c *gin.Context) bool {
+ return false
+}
+
+func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var texts = make([]string, 0)
+
+ for _, document := range r.Documents {
+ texts = append(texts, fmt.Sprintf("%v", document))
+ }
+
+ if r.Query != "" {
+ texts = append(texts, r.Query)
+ }
+
+ return &types.TokenCountMeta{
+ CombineText: strings.Join(texts, "\n"),
+ }
+}
+
func (r *RerankRequest) GetReturnDocuments() bool {
if r.ReturnDocuments == nil {
return false
diff --git a/logger/logger.go b/logger/logger.go
new file mode 100644
index 00000000..ca81d624
--- /dev/null
+++ b/logger/logger.go
@@ -0,0 +1,115 @@
+package logger
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "github.com/bytedance/gopkg/util/gopool"
+ "github.com/gin-gonic/gin"
+ "io"
+ "log"
+ "one-api/common"
+ "os"
+ "path/filepath"
+ "sync"
+ "time"
+)
+
+const (
+ loggerINFO = "INFO"
+ loggerWarn = "WARN"
+ loggerError = "ERR"
+ loggerDebug = "DEBUG"
+)
+
+const maxLogCount = 1000000
+
+var logCount int
+var setupLogLock sync.Mutex
+var setupLogWorking bool
+
+func SetupLogger() {
+ if *common.LogDir != "" {
+ ok := setupLogLock.TryLock()
+ if !ok {
+ log.Println("setup log is already working")
+ return
+ }
+ defer func() {
+ setupLogLock.Unlock()
+ setupLogWorking = false
+ }()
+ logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
+ fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ log.Fatal("failed to open log file")
+ }
+ gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
+ gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
+ }
+}
+
+func LogInfo(ctx context.Context, msg string) {
+ logHelper(ctx, loggerINFO, msg)
+}
+
+func LogWarn(ctx context.Context, msg string) {
+ logHelper(ctx, loggerWarn, msg)
+}
+
+func LogError(ctx context.Context, msg string) {
+ logHelper(ctx, loggerError, msg)
+}
+
+func LogDebug(ctx context.Context, msg string) {
+ if common.DebugEnabled {
+ logHelper(ctx, loggerDebug, msg)
+ }
+}
+
+func logHelper(ctx context.Context, level string, msg string) {
+ writer := gin.DefaultErrorWriter
+ if level == loggerINFO {
+ writer = gin.DefaultWriter
+ }
+ id := ctx.Value(common.RequestIdKey)
+ if id == nil {
+ id = "SYSTEM"
+ }
+ now := time.Now()
+ _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
+ logCount++ // we don't need accurate count, so no lock here
+ if logCount > maxLogCount && !setupLogWorking {
+ logCount = 0
+ setupLogWorking = true
+ gopool.Go(func() {
+ SetupLogger()
+ })
+ }
+}
+
+func LogQuota(quota int) string {
+ if common.DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit)
+ } else {
+ return fmt.Sprintf("%d 点额度", quota)
+ }
+}
+
+func FormatQuota(quota int) string {
+ if common.DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit)
+ } else {
+ return fmt.Sprintf("%d", quota)
+ }
+}
+
+// LogJson 仅供测试使用 only for test
+func LogJson(ctx context.Context, msg string, obj any) {
+ jsonStr, err := json.Marshal(obj)
+ if err != nil {
+ LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
+ return
+ }
+ LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
+}
diff --git a/main.go b/main.go
index ca3da601..9a5bd652 100644
--- a/main.go
+++ b/main.go
@@ -8,6 +8,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/controller"
+ "one-api/logger"
"one-api/middleware"
"one-api/model"
"one-api/router"
@@ -35,22 +36,22 @@ func main() {
err := InitResources()
if err != nil {
- common.FatalLog("failed to initialize resources: " + err.Error())
+ logger.FatalLog("failed to initialize resources: " + err.Error())
return
}
- common.SysLog("New API " + common.Version + " started")
+ logger.SysLog("New API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
if common.DebugEnabled {
- common.SysLog("running in debug mode")
+ logger.SysLog("running in debug mode")
}
defer func() {
err := model.CloseDB()
if err != nil {
- common.FatalLog("failed to close database: " + err.Error())
+ logger.FatalLog("failed to close database: " + err.Error())
}
}()
@@ -59,18 +60,18 @@ func main() {
common.MemoryCacheEnabled = true
}
if common.MemoryCacheEnabled {
- common.SysLog("memory cache enabled")
- common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
+ logger.SysLog("memory cache enabled")
+ logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
// Add panic recovery and retry for InitChannelCache
func() {
defer func() {
if r := recover(); r != nil {
- common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
+ logger.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
// Retry once
_, _, fixErr := model.FixAbility()
if fixErr != nil {
- common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
+ logger.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
}
}
}()
@@ -89,14 +90,14 @@ func main() {
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
- common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
+ logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyUpdateChannels(frequency)
}
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
- common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
+ logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyTestChannels(frequency)
}
@@ -110,7 +111,7 @@ func main() {
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
- common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
+ logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
@@ -119,13 +120,13 @@ func main() {
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
})
go common.Monitor()
- common.SysLog("pprof enabled")
+ logger.SysLog("pprof enabled")
}
// Initialize HTTP server
server := gin.New()
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
- common.SysError(fmt.Sprintf("panic detected: %v", err))
+ logger.SysError(fmt.Sprintf("panic detected: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
@@ -155,7 +156,7 @@ func main() {
}
err = server.Run(":" + port)
if err != nil {
- common.FatalLog("failed to start HTTP server: " + err.Error())
+ logger.FatalLog("failed to start HTTP server: " + err.Error())
}
}
@@ -164,14 +165,14 @@ func InitResources() error {
// This is a placeholder function for future resource initialization
err := godotenv.Load(".env")
if err != nil {
- common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
- common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
+ logger.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
+ logger.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
}
// 加载环境变量
common.InitEnv()
- common.SetupLogger()
+ logger.SetupLogger()
// Initialize model settings
ratio_setting.InitRatioSettings()
@@ -183,7 +184,7 @@ func InitResources() error {
// Initialize SQL Database
err = model.InitDB()
if err != nil {
- common.FatalLog("failed to initialize database: " + err.Error())
+ logger.FatalLog("failed to initialize database: " + err.Error())
return err
}
diff --git a/middleware/recover.go b/middleware/recover.go
index 51fc7190..6c9c7ef6 100644
--- a/middleware/recover.go
+++ b/middleware/recover.go
@@ -4,7 +4,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
- "one-api/common"
+ "one-api/logger"
"runtime/debug"
)
@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
- common.SysError(fmt.Sprintf("panic detected: %v", err))
- common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
+ logger.SysError(fmt.Sprintf("panic detected: %v", err))
+ logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go
index 26688810..a136a900 100644
--- a/middleware/turnstile-check.go
+++ b/middleware/turnstile-check.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"one-api/common"
+ "one-api/logger"
)
type turnstileCheckResponse struct {
@@ -37,7 +38,7 @@ func TurnstileCheck() gin.HandlerFunc {
"remoteip": {c.ClientIP()},
})
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
diff --git a/middleware/utils.go b/middleware/utils.go
index 082f5657..e23bbff7 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
+ "one-api/logger"
)
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
@@ -15,7 +16,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
},
})
c.Abort()
- common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
+ logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
}
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
@@ -25,5 +26,5 @@ func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, descri
"code": code,
})
c.Abort()
- common.LogError(c.Request.Context(), description)
+ logger.LogError(c.Request.Context(), description)
}
diff --git a/model/ability.go b/model/ability.go
index ce2f299c..ac5530d8 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/logger"
"strings"
"sync"
@@ -294,13 +295,13 @@ func FixAbility() (int, int, error) {
if common.UsingSQLite {
err := DB.Exec("DELETE FROM abilities").Error
if err != nil {
- common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
return 0, 0, err
}
} else {
err := DB.Exec("TRUNCATE TABLE abilities").Error
if err != nil {
- common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
return 0, 0, err
}
}
@@ -320,7 +321,7 @@ func FixAbility() (int, int, error) {
// Delete all abilities of this channel
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
if err != nil {
- common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
failCount += len(chunk)
continue
}
@@ -328,7 +329,7 @@ func FixAbility() (int, int, error) {
for _, channel := range chunk {
err = channel.AddAbilities(nil)
if err != nil {
- common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
+ logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
failCount++
} else {
successCount++
diff --git a/model/channel.go b/model/channel.go
index 6239f05c..c0d253fc 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/types"
"strings"
"sync"
@@ -209,7 +210,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
if channel.OtherInfo != "" {
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
if err != nil {
- common.SysError("failed to unmarshal other info: " + err.Error())
+ logger.SysError("failed to unmarshal other info: " + err.Error())
}
}
return otherInfo
@@ -218,7 +219,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
otherInfoBytes, err := json.Marshal(otherInfo)
if err != nil {
- common.SysError("failed to marshal other info: " + err.Error())
+ logger.SysError("failed to marshal other info: " + err.Error())
return
}
channel.OtherInfo = string(otherInfoBytes)
@@ -488,7 +489,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
ResponseTime: int(responseTime),
}).Error
if err != nil {
- common.SysError("failed to update response time: " + err.Error())
+ logger.SysError("failed to update response time: " + err.Error())
}
}
@@ -498,7 +499,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
Balance: balance,
}).Error
if err != nil {
- common.SysError("failed to update balance: " + err.Error())
+ logger.SysError("failed to update balance: " + err.Error())
}
}
@@ -614,7 +615,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
if shouldUpdateAbilities {
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
if err != nil {
- common.SysError("failed to update ability status: " + err.Error())
+ logger.SysError("failed to update ability status: " + err.Error())
}
}
}()
@@ -642,7 +643,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
}
err = channel.Save()
if err != nil {
- common.SysError("failed to update channel status: " + err.Error())
+ logger.SysError("failed to update channel status: " + err.Error())
return false
}
}
@@ -704,7 +705,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
for _, channel := range channels {
err = channel.UpdateAbilities(nil)
if err != nil {
- common.SysError("failed to update abilities: " + err.Error())
+ logger.SysError("failed to update abilities: " + err.Error())
}
}
}
@@ -728,7 +729,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
- common.SysError("failed to update channel used quota: " + err.Error())
+ logger.SysError("failed to update channel used quota: " + err.Error())
}
}
@@ -821,7 +822,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
if channel.Setting != nil && *channel.Setting != "" {
err := common.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil {
- common.SysError("failed to unmarshal setting: " + err.Error())
+ logger.SysError("failed to unmarshal setting: " + err.Error())
channel.Setting = nil // 清空设置以避免后续错误
_ = channel.Save() // 保存修改
}
@@ -832,7 +833,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
settingBytes, err := common.Marshal(setting)
if err != nil {
- common.SysError("failed to marshal setting: " + err.Error())
+ logger.SysError("failed to marshal setting: " + err.Error())
return
}
channel.Setting = common.GetPointer[string](string(settingBytes))
@@ -843,7 +844,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
if channel.OtherSettings != "" {
err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
if err != nil {
- common.SysError("failed to unmarshal setting: " + err.Error())
+ logger.SysError("failed to unmarshal setting: " + err.Error())
channel.OtherSettings = "{}" // 清空设置以避免后续错误
_ = channel.Save() // 保存修改
}
@@ -854,7 +855,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
settingBytes, err := common.Marshal(setting)
if err != nil {
- common.SysError("failed to marshal setting: " + err.Error())
+ logger.SysError("failed to marshal setting: " + err.Error())
return
}
channel.OtherSettings = string(settingBytes)
@@ -865,7 +866,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
if err != nil {
- common.SysError("failed to unmarshal param override: " + err.Error())
+ logger.SysError("failed to unmarshal param override: " + err.Error())
}
}
return paramOverride
diff --git a/model/channel_cache.go b/model/channel_cache.go
index 86866e40..22216027 100644
--- a/model/channel_cache.go
+++ b/model/channel_cache.go
@@ -6,6 +6,7 @@ import (
"math/rand"
"one-api/common"
"one-api/constant"
+ "one-api/logger"
"one-api/setting"
"one-api/setting/ratio_setting"
"sort"
@@ -84,13 +85,13 @@ func InitChannelCache() {
}
channelsIDM = newChannelId2channel
channelSyncLock.Unlock()
- common.SysLog("channels synced from database")
+ logger.SysLog("channels synced from database")
}
func SyncChannelCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing channels from database")
+ logger.SysLog("syncing channels from database")
InitChannelCache()
}
}
diff --git a/model/log.go b/model/log.go
index 2070cd6f..d9495968 100644
--- a/model/log.go
+++ b/model/log.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"one-api/common"
+ "one-api/logger"
"os"
"strings"
"time"
@@ -87,13 +88,13 @@ func RecordLog(userId int, logType int, content string) {
}
err := LOG_DB.Create(log).Error
if err != nil {
- common.SysError("failed to record log: " + err.Error())
+ logger.SysError("failed to record log: " + err.Error())
}
}
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
- common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
+ logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
// 判断是否需要记录 IP
@@ -129,7 +130,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
}
err := LOG_DB.Create(log).Error
if err != nil {
- common.LogError(c, "failed to record log: "+err.Error())
+ logger.LogError(c, "failed to record log: "+err.Error())
}
}
@@ -142,7 +143,6 @@ type RecordConsumeLogParams struct {
Quota int `json:"quota"`
Content string `json:"content"`
TokenId int `json:"token_id"`
- UserQuota int `json:"user_quota"`
UseTimeSeconds int `json:"use_time_seconds"`
IsStream bool `json:"is_stream"`
Group string `json:"group"`
@@ -150,7 +150,7 @@ type RecordConsumeLogParams struct {
}
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
- common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
+ logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
if !common.LogConsumeEnabled {
return
}
@@ -189,7 +189,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
}
err := LOG_DB.Create(log).Error
if err != nil {
- common.LogError(c, "failed to record log: "+err.Error())
+ logger.LogError(c, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {
diff --git a/model/main.go b/model/main.go
index dbf27152..1e582e1a 100644
--- a/model/main.go
+++ b/model/main.go
@@ -5,6 +5,7 @@ import (
"log"
"one-api/common"
"one-api/constant"
+ "one-api/logger"
"os"
"strings"
"sync"
@@ -84,7 +85,7 @@ func createRootAccountIfNeed() error {
var user User
//if user.Status != common.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
- common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
+ logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456")
if err != nil {
return err
@@ -108,7 +109,7 @@ func CheckSetup() {
if setup == nil {
// No setup record exists, check if we have a root user
if RootUserExists() {
- common.SysLog("system is not initialized, but root user exists")
+ logger.SysLog("system is not initialized, but root user exists")
// Create setup record
newSetup := Setup{
Version: common.Version,
@@ -116,16 +117,16 @@ func CheckSetup() {
}
err := DB.Create(&newSetup).Error
if err != nil {
- common.SysLog("failed to create setup record: " + err.Error())
+ logger.SysLog("failed to create setup record: " + err.Error())
}
constant.Setup = true
} else {
- common.SysLog("system is not initialized and no root user exists")
+ logger.SysLog("system is not initialized and no root user exists")
constant.Setup = false
}
} else {
// Setup record exists, system is initialized
- common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
+ logger.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
constant.Setup = true
}
}
@@ -138,7 +139,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
if dsn != "" {
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
// Use PostgreSQL
- common.SysLog("using PostgreSQL as database")
+ logger.SysLog("using PostgreSQL as database")
if !isLog {
common.UsingPostgreSQL = true
} else {
@@ -152,7 +153,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
})
}
if strings.HasPrefix(dsn, "local") {
- common.SysLog("SQL_DSN not set, using SQLite as database")
+ logger.SysLog("SQL_DSN not set, using SQLite as database")
if !isLog {
common.UsingSQLite = true
} else {
@@ -163,7 +164,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
})
}
// Use MySQL
- common.SysLog("using MySQL as database")
+ logger.SysLog("using MySQL as database")
// check parseTime
if !strings.Contains(dsn, "parseTime") {
if strings.Contains(dsn, "?") {
@@ -182,7 +183,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
})
}
// Use SQLite
- common.SysLog("SQL_DSN not set, using SQLite as database")
+ logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
@@ -216,11 +217,11 @@ func InitDB() (err error) {
if common.UsingMySQL {
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
}
- common.SysLog("database migration started")
+ logger.SysLog("database migration started")
err = migrateDB()
return err
} else {
- common.FatalLog(err)
+ logger.FatalLog(err)
}
return err
}
@@ -253,11 +254,11 @@ func InitLogDB() (err error) {
if !common.IsMasterNode {
return nil
}
- common.SysLog("database migration started")
+ logger.SysLog("database migration started")
err = migrateLOGDB()
return err
} else {
- common.FatalLog(err)
+ logger.FatalLog(err)
}
return err
}
@@ -354,7 +355,7 @@ func migrateDBFast() error {
return err
}
}
- common.SysLog("database migrated")
+ logger.SysLog("database migrated")
return nil
}
@@ -503,6 +504,6 @@ func PingDB() error {
}
lastPingTime = time.Now()
- common.SysLog("Database pinged successfully")
+ logger.SysLog("Database pinged successfully")
return nil
}
diff --git a/model/option.go b/model/option.go
index 5c84d166..8fcd13a8 100644
--- a/model/option.go
+++ b/model/option.go
@@ -2,6 +2,7 @@ package model
import (
"one-api/common"
+ "one-api/logger"
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
@@ -150,7 +151,7 @@ func loadOptionsFromDatabase() {
for _, option := range options {
err := updateOptionMap(option.Key, option.Value)
if err != nil {
- common.SysError("failed to update option map: " + err.Error())
+ logger.SysError("failed to update option map: " + err.Error())
}
}
}
@@ -158,7 +159,7 @@ func loadOptionsFromDatabase() {
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing options from database")
+ logger.SysLog("syncing options from database")
loadOptionsFromDatabase()
}
}
diff --git a/model/pricing.go b/model/pricing.go
index 0936d298..31aa5cdf 100644
--- a/model/pricing.go
+++ b/model/pricing.go
@@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"fmt"
+ "one-api/logger"
"strings"
"one-api/common"
@@ -92,7 +93,7 @@ func updatePricing() {
//modelRatios := common.GetModelRatios()
enableAbilities, err := GetAllEnableAbilityWithChannels()
if err != nil {
- common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
+ logger.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
return
}
// 预加载模型元数据与供应商一次,避免循环查询
diff --git a/model/redemption.go b/model/redemption.go
index bf237668..1ab84f45 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/logger"
"strconv"
"gorm.io/gorm"
@@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return 0, errors.New("兑换失败," + err.Error())
}
- RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
+ RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
return redemption.Quota, nil
}
diff --git a/model/token.go b/model/token.go
index e85a445e..63c17e2d 100644
--- a/model/token.go
+++ b/model/token.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/logger"
"strings"
"github.com/bytedance/gopkg/util/gopool"
@@ -91,7 +92,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
- common.SysError("failed to update token status" + err.Error())
+ logger.SysError("failed to update token status" + err.Error())
}
}
return token, errors.New("该令牌已过期")
@@ -102,7 +103,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
- common.SysError("failed to update token status" + err.Error())
+ logger.SysError("failed to update token status" + err.Error())
}
}
keyPrefix := key[:3]
@@ -134,7 +135,7 @@ func GetTokenById(id int) (*Token, error) {
if shouldUpdateRedis(true, err) {
gopool.Go(func() {
if err := cacheSetToken(token); err != nil {
- common.SysError("failed to update user status cache: " + err.Error())
+ logger.SysError("failed to update user status cache: " + err.Error())
}
})
}
@@ -147,7 +148,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
if shouldUpdateRedis(fromDB, err) && token != nil {
gopool.Go(func() {
if err := cacheSetToken(*token); err != nil {
- common.SysError("failed to update user status cache: " + err.Error())
+ logger.SysError("failed to update user status cache: " + err.Error())
}
})
}
@@ -178,7 +179,7 @@ func (token *Token) Update() (err error) {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
- common.SysError("failed to update token cache: " + err.Error())
+ logger.SysError("failed to update token cache: " + err.Error())
}
})
}
@@ -194,7 +195,7 @@ func (token *Token) SelectUpdate() (err error) {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
- common.SysError("failed to update token cache: " + err.Error())
+ logger.SysError("failed to update token cache: " + err.Error())
}
})
}
@@ -209,7 +210,7 @@ func (token *Token) Delete() (err error) {
gopool.Go(func() {
err := cacheDeleteToken(token.Key)
if err != nil {
- common.SysError("failed to delete token cache: " + err.Error())
+ logger.SysError("failed to delete token cache: " + err.Error())
}
})
}
@@ -269,7 +270,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
gopool.Go(func() {
err := cacheIncrTokenQuota(key, int64(quota))
if err != nil {
- common.SysError("failed to increase token quota: " + err.Error())
+ logger.SysError("failed to increase token quota: " + err.Error())
}
})
}
@@ -299,7 +300,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) {
gopool.Go(func() {
err := cacheDecrTokenQuota(key, int64(quota))
if err != nil {
- common.SysError("failed to decrease token quota: " + err.Error())
+ logger.SysError("failed to decrease token quota: " + err.Error())
}
})
}
diff --git a/model/topup.go b/model/topup.go
index c34c0ce6..802c866f 100644
--- a/model/topup.go
+++ b/model/topup.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/logger"
"gorm.io/gorm"
)
@@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) {
return errors.New("充值失败," + err.Error())
}
- RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount))
+ RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount))
return nil
}
diff --git a/model/twofa.go b/model/twofa.go
index d09ff9fe..b2ea54e0 100644
--- a/model/twofa.go
+++ b/model/twofa.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/logger"
"time"
"gorm.io/gorm"
@@ -243,7 +244,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
if !common.ValidateTOTPCode(t.Secret, code) {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
- common.SysError("更新2FA失败次数失败: " + err.Error())
+ logger.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
@@ -255,7 +256,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
t.LastUsedAt = &now
if err := t.Update(); err != nil {
- common.SysError("更新2FA使用记录失败: " + err.Error())
+ logger.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil
@@ -277,7 +278,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
if !valid {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
- common.SysError("更新2FA失败次数失败: " + err.Error())
+ logger.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
@@ -289,7 +290,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
t.LastUsedAt = &now
if err := t.Update(); err != nil {
- common.SysError("更新2FA使用记录失败: " + err.Error())
+ logger.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil
diff --git a/model/usedata.go b/model/usedata.go
index 1255b0be..f0027a8d 100644
--- a/model/usedata.go
+++ b/model/usedata.go
@@ -4,6 +4,7 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/logger"
"sync"
"time"
)
@@ -24,12 +25,12 @@ func UpdateQuotaData() {
// recover
defer func() {
if r := recover(); r != nil {
- common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
+ logger.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
}
}()
for {
if common.DataExportEnabled {
- common.SysLog("正在更新数据看板数据...")
+ logger.SysLog("正在更新数据看板数据...")
SaveQuotaDataCache()
}
time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
@@ -91,7 +92,7 @@ func SaveQuotaDataCache() {
}
}
CacheQuotaData = make(map[string]*QuotaData)
- common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
+ logger.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
}
func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) {
@@ -102,7 +103,7 @@ func increaseQuotaData(userId int, username string, modelName string, count int,
"token_used": gorm.Expr("token_used + ?", tokenUsed),
}).Error
if err != nil {
- common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
+ logger.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
}
}
diff --git a/model/user.go b/model/user.go
index 6021f495..244380ad 100644
--- a/model/user.go
+++ b/model/user.go
@@ -6,6 +6,7 @@ import (
"fmt"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"strconv"
"strings"
@@ -75,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting {
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
- common.SysError("failed to unmarshal setting: " + err.Error())
+ logger.SysError("failed to unmarshal setting: " + err.Error())
}
}
return setting
@@ -84,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting {
func (user *User) SetSetting(setting dto.UserSetting) {
settingBytes, err := json.Marshal(setting)
if err != nil {
- common.SysError("failed to marshal setting: " + err.Error())
+ logger.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
@@ -274,7 +275,7 @@ func inviteUser(inviterId int) (err error) {
func (user *User) TransferAffQuotaToQuota(quota int) error {
// 检查quota是否小于最小额度
if float64(quota) < common.QuotaPerUnit {
- return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit)))
+ return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit)))
}
// 开始数据库事务
@@ -324,16 +325,16 @@ func (user *User) Insert(inviterId int) error {
return result.Error
}
if common.QuotaForNewUser > 0 {
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
//_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
- RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
_ = inviteUser(inviterId)
}
}
@@ -517,7 +518,7 @@ func IsAdmin(userId int) bool {
var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil {
- common.SysError("no such user " + err.Error())
+ logger.SysError("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
@@ -572,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserQuotaCache(id, quota); err != nil {
- common.SysError("failed to update user quota cache: " + err.Error())
+ logger.SysError("failed to update user quota cache: " + err.Error())
}
})
}
@@ -610,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserGroupCache(id, group); err != nil {
- common.SysError("failed to update user group cache: " + err.Error())
+ logger.SysError("failed to update user group cache: " + err.Error())
}
})
}
@@ -639,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserSettingCache(id, setting); err != nil {
- common.SysError("failed to update user setting cache: " + err.Error())
+ logger.SysError("failed to update user setting cache: " + err.Error())
}
})
}
@@ -669,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) {
gopool.Go(func() {
err := cacheIncrUserQuota(id, int64(quota))
if err != nil {
- common.SysError("failed to increase user quota: " + err.Error())
+ logger.SysError("failed to increase user quota: " + err.Error())
}
})
if !db && common.BatchUpdateEnabled {
@@ -694,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
gopool.Go(func() {
err := cacheDecrUserQuota(id, int64(quota))
if err != nil {
- common.SysError("failed to decrease user quota: " + err.Error())
+ logger.SysError("failed to decrease user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
@@ -750,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota and request count: " + err.Error())
+ logger.SysError("failed to update user used quota and request count: " + err.Error())
return
}
@@ -767,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota: " + err.Error())
+ logger.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
- common.SysError("failed to update user request count: " + err.Error())
+ logger.SysError("failed to update user request count: " + err.Error())
}
}
@@ -785,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserNameCache(id, username); err != nil {
- common.SysError("failed to update user name cache: " + err.Error())
+ logger.SysError("failed to update user name cache: " + err.Error())
}
})
}
diff --git a/model/user_cache.go b/model/user_cache.go
index a631457c..dec7597b 100644
--- a/model/user_cache.go
+++ b/model/user_cache.go
@@ -5,6 +5,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"time"
"github.com/gin-gonic/gin"
@@ -37,7 +38,7 @@ func (user *UserBase) GetSetting() dto.UserSetting {
if user.Setting != "" {
err := common.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
- common.SysError("failed to unmarshal setting: " + err.Error())
+ logger.SysError("failed to unmarshal setting: " + err.Error())
}
}
return setting
@@ -78,7 +79,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) {
if shouldUpdateRedis(fromDB, err) && user != nil {
gopool.Go(func() {
if err := updateUserCache(*user); err != nil {
- common.SysError("failed to update user status cache: " + err.Error())
+ logger.SysError("failed to update user status cache: " + err.Error())
}
})
}
diff --git a/model/utils.go b/model/utils.go
index 1f8a0963..abd96b79 100644
--- a/model/utils.go
+++ b/model/utils.go
@@ -3,6 +3,7 @@ package model
import (
"errors"
"one-api/common"
+ "one-api/logger"
"sync"
"time"
@@ -65,7 +66,7 @@ func batchUpdate() {
return
}
- common.SysLog("batch update started")
+ logger.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
@@ -77,12 +78,12 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
- common.SysError("failed to batch update user quota: " + err.Error())
+ logger.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
- common.SysError("failed to batch update token quota: " + err.Error())
+ logger.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
@@ -93,7 +94,7 @@ func batchUpdate() {
}
}
}
- common.SysLog("batch update finished")
+ logger.SysLog("batch update finished")
}
func RecordExist(err error) (bool, error) {
diff --git a/relay/audio_handler.go b/relay/audio_handler.go
index 88777838..1bc2d90b 100644
--- a/relay/audio_handler.go
+++ b/relay/audio_handler.go
@@ -4,107 +4,40 @@ import (
"errors"
"fmt"
"net/http"
- "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
- "one-api/setting"
"one-api/types"
- "strings"
"github.com/gin-gonic/gin"
)
-func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
- audioRequest := &dto.AudioRequest{}
- err := common.UnmarshalBodyReusable(c, audioRequest)
- if err != nil {
- return nil, err
- }
- switch info.RelayMode {
- case relayconstant.RelayModeAudioSpeech:
- if audioRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- if setting.ShouldCheckPromptSensitive() {
- words, err := service.CheckSensitiveInput(audioRequest.Input)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
- return nil, err
- }
- }
- default:
- err = c.Request.ParseForm()
- if err != nil {
- return nil, err
- }
- formData := c.Request.PostForm
- if audioRequest.Model == "" {
- audioRequest.Model = formData.Get("model")
- }
+func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
- if audioRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- audioRequest.ResponseFormat = formData.Get("response_format")
- if audioRequest.ResponseFormat == "" {
- audioRequest.ResponseFormat = "json"
- }
- }
- return audioRequest, nil
-}
-
-func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
- relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
- audioRequest, err := getAndValidAudioRequest(c, relayInfo)
-
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ audioRequest, ok := info.Request.(*dto.AudioRequest)
+ if !ok {
+ return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
- promptTokens := 0
- preConsumedTokens := common.PreConsumedQuota
- if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
- promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
- preConsumedTokens = promptTokens
- relayInfo.PromptTokens = promptTokens
- }
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
- if err != nil {
- return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
- }
-
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
+ err := helper.ModelMappedHelper(c, info, audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
- ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
+ ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
- resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
+ resp, err := adaptor.DoRequest(c, info, ioReader)
if err != nil {
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
@@ -121,14 +54,14 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
- usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}
diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go
index 754f29c8..841896cf 100644
--- a/relay/channel/ali/image.go
+++ b/relay/channel/ali/image.go
@@ -6,8 +6,8 @@ import (
"fmt"
"io"
"net/http"
- "one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
@@ -43,7 +43,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
- common.SysError("updateTask client.Do err: " + err.Error())
+ logger.SysError("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
@@ -53,7 +53,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
var response AliResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
- common.SysError("updateTask NewDecoder err: " + err.Error())
+ logger.SysError("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
@@ -109,7 +109,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
- common.LogError(c, "get_image_data_failed: "+err.Error())
+ logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
@@ -134,14 +134,14 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
- common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
+ logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
}
diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go
index 4f448e01..e7d6b514 100644
--- a/relay/channel/ali/rerank.go
+++ b/relay/channel/ali/rerank.go
@@ -4,9 +4,9 @@ import (
"encoding/json"
"io"
"net/http"
- "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -36,7 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse)
diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go
index fcf63854..17fcef2a 100644
--- a/relay/channel/ali/text.go
+++ b/relay/channel/ali/text.go
@@ -7,7 +7,9 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/relay/helper"
+ "one-api/service"
"strings"
"one-api/types"
@@ -46,7 +48,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
model := c.GetString("model")
if model == "" {
@@ -148,7 +150,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
@@ -161,7 +163,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -171,7 +173,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
return false
}
})
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
return nil, &usage
}
@@ -181,7 +183,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index 3ccd2d78..fd745cf7 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
common2 "one-api/common"
+ "one-api/logger"
"one-api/relay/common"
"one-api/relay/constant"
"one-api/relay/helper"
@@ -181,7 +182,7 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
err := helper.PingData(c)
if err != nil {
- common2.LogError(c, "SSE ping error: "+err.Error())
+ logger.LogError(c, "SSE ping error: "+err.Error())
done <- err
return
}
diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go
index a7cd5996..696c2496 100644
--- a/relay/channel/baidu/relay-baidu.go
+++ b/relay/channel/baidu/relay-baidu.go
@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -118,7 +119,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
var baiduResponse BaiduChatStreamResponse
err := common.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
@@ -129,11 +130,11 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
response := streamResponseBaidu2OpenAI(&baiduResponse)
err = helper.ObjectData(c, response)
if err != nil {
- common.SysError("error sending stream response: " + err.Error())
+ logger.SysError("error sending stream response: " + err.Error())
}
return true
})
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
return nil, usage
}
@@ -143,7 +144,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
@@ -168,7 +169,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index e4d3975e..5d839908 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -7,6 +7,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -375,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
for _, toolCall := range message.ParseToolCalls() {
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
- common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
+ logger.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
@@ -609,7 +610,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
var claudeResponse dto.ClaudeResponse
err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
@@ -637,7 +638,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
err = helper.ObjectData(c, response)
if err != nil {
- common.LogError(c, "send_stream_response_failed: "+err.Error())
+ logger.LogError(c, "send_stream_response_failed: "+err.Error())
}
}
return nil
@@ -653,7 +654,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
- common.SysError("claude response usage is not complete, maybe upstream error")
+ logger.SysError("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
@@ -667,7 +668,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
if err != nil {
- common.SysError("send final response failed: " + err.Error())
+ logger.SysError("send final response failed: " + err.Error())
}
}
helper.Done(c)
@@ -736,12 +737,12 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
}
- common.IOCopyBytesGracefully(c, nil, responseData)
+ service.IOCopyBytesGracefully(c, nil, responseData)
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),
diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go
index 5e8fe7f9..00f6b6c5 100644
--- a/relay/channel/cloudflare/relay_cloudflare.go
+++ b/relay/channel/cloudflare/relay_cloudflare.go
@@ -5,8 +5,8 @@ import (
"encoding/json"
"io"
"net/http"
- "one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -51,7 +51,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
- common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
+ logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue
}
for _, choice := range response.Choices {
@@ -66,24 +66,24 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
info.FirstResponseTime = time.Now()
}
if err != nil {
- common.LogError(c, "error_rendering_stream_response: "+err.Error())
+ logger.LogError(c, "error_rendering_stream_response: "+err.Error())
}
}
if err := scanner.Err(); err != nil {
- common.LogError(c, "error_scanning_stream_response: "+err.Error())
+ logger.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
- common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
+ logger.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
helper.Done(c)
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
return nil, usage
}
@@ -93,7 +93,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
@@ -123,7 +123,7 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &cfResp)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go
index fcfb12b7..ccef9b23 100644
--- a/relay/channel/cohere/relay-cohere.go
+++ b/relay/channel/cohere/relay-cohere.go
@@ -7,6 +7,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -118,7 +119,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResp dto.ChatCompletionsStreamResponse
@@ -153,7 +154,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
jsonStr, err := json.Marshal(openaiResp)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
@@ -175,7 +176,7 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
@@ -216,7 +217,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
index 32cc6937..18ed46af 100644
--- a/relay/channel/coze/relay-coze.go
+++ b/relay/channel/coze/relay-coze.go
@@ -9,6 +9,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -49,7 +50,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
// convert coze response to openai response
var response dto.TextResponse
var cozeResponse CozeChatDetailResponse
@@ -154,7 +155,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var chatData CozeChatResponseData
err := json.Unmarshal([]byte(data), &chatData)
if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ logger.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
@@ -171,14 +172,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var messageData CozeChatV3MessageDetail
err := json.Unmarshal([]byte(data), &messageData)
if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ logger.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
var content string
err = json.Unmarshal(messageData.Content, &content)
if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ logger.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
@@ -203,11 +204,11 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var errorData CozeError
err := json.Unmarshal([]byte(data), &errorData)
if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ logger.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
- common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
+ logger.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
}
}
diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go
index 47337127..f03d61a4 100644
--- a/relay/channel/dify/relay-dify.go
+++ b/relay/channel/dify/relay-dify.go
@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -36,14 +37,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Decode base64 string
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
- common.SysError("failed to decode base64: " + err.Error())
+ logger.SysError("failed to decode base64: " + err.Error())
return nil
}
// Create temporary file
tempFile, err := os.CreateTemp("", "dify-upload-*")
if err != nil {
- common.SysError("failed to create temp file: " + err.Error())
+ logger.SysError("failed to create temp file: " + err.Error())
return nil
}
defer tempFile.Close()
@@ -51,7 +52,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Write decoded data to temp file
if _, err := tempFile.Write(decodedData); err != nil {
- common.SysError("failed to write to temp file: " + err.Error())
+ logger.SysError("failed to write to temp file: " + err.Error())
return nil
}
@@ -61,7 +62,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Add user field
if err := writer.WriteField("user", user); err != nil {
- common.SysError("failed to add user field: " + err.Error())
+ logger.SysError("failed to add user field: " + err.Error())
return nil
}
@@ -74,13 +75,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create form file
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
if err != nil {
- common.SysError("failed to create form file: " + err.Error())
+ logger.SysError("failed to create form file: " + err.Error())
return nil
}
// Copy file content to form
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
- common.SysError("failed to copy file content: " + err.Error())
+ logger.SysError("failed to copy file content: " + err.Error())
return nil
}
writer.Close()
@@ -88,7 +89,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create HTTP request
req, err := http.NewRequest("POST", uploadUrl, body)
if err != nil {
- common.SysError("failed to create request: " + err.Error())
+ logger.SysError("failed to create request: " + err.Error())
return nil
}
@@ -99,7 +100,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
client := service.GetHttpClient()
resp, err := client.Do(req)
if err != nil {
- common.SysError("failed to send request: " + err.Error())
+ logger.SysError("failed to send request: " + err.Error())
return nil
}
defer resp.Body.Close()
@@ -109,7 +110,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
Id string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
- common.SysError("failed to decode response: " + err.Error())
+ logger.SysError("failed to decode response: " + err.Error())
return nil
}
@@ -219,7 +220,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResponse dto.ChatCompletionsStreamResponse
@@ -239,7 +240,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
}
err = helper.ObjectData(c, openaiResponse)
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
}
return true
})
@@ -258,7 +259,7 @@ func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 4141caf7..05d974f6 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
},
},
Parameters: dto.GeminiImageParameters{
- SampleCount: request.N,
+ SampleCount: int(request.N),
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
},
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index 7f2f51fb..974a22f5 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -5,6 +5,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -17,7 +18,7 @@ import (
)
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
@@ -53,13 +54,13 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
}
}
- common.IOCopyBytesGracefully(c, resp, responseBody)
+ service.IOCopyBytesGracefully(c, resp, responseBody)
return &usage, nil
}
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -89,7 +90,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
}
}
- common.IOCopyBytesGracefully(c, resp, responseBody)
+ service.IOCopyBytesGracefully(c, resp, responseBody)
return usage, nil
}
@@ -106,7 +107,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
- common.LogError(c, "error unmarshalling stream response: "+err.Error())
+ logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
@@ -140,7 +141,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
// 直接发送 GeminiChatResponse 响应
err = helper.StringData(c, data)
if err != nil {
- common.LogError(c, err.Error())
+ logger.LogError(c, err.Error())
}
info.SendResponseCount++
return true
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 58efa1a5..82a2d8de 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -901,7 +902,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
- common.LogError(c, "error unmarshalling stream response: "+err.Error())
+ logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
@@ -945,7 +946,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
finishReason = constant.FinishReasonToolCalls
err = handleStream(c, info, emptyResponse)
if err != nil {
- common.LogError(c, err.Error())
+ logger.LogError(c, err.Error())
}
response.ClearToolCalls()
@@ -957,7 +958,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
err = handleStream(c, info, response)
if err != nil {
- common.LogError(c, err.Error())
+ logger.LogError(c, err.Error())
}
if isStop {
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
@@ -993,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := handleFinalStream(c, info, response)
if err != nil {
- common.SysError("send final response failed: " + err.Error())
+ logger.SysError("send final response failed: " + err.Error())
}
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
// helper.Done(c)
@@ -1007,7 +1008,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println(string(responseBody))
}
@@ -1057,13 +1058,13 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
break
}
- common.IOCopyBytesGracefully(c, resp, responseBody)
+ service.IOCopyBytesGracefully(c, resp, responseBody)
return &usage, nil
}
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
@@ -1107,7 +1108,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}
diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go
index 28af1866..11a0117b 100644
--- a/relay/channel/jimeng/image.go
+++ b/relay/channel/jimeng/image.go
@@ -5,9 +5,9 @@ import (
"fmt"
"io"
"net/http"
- "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -54,7 +54,7 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &jimengResponse)
if err != nil {
diff --git a/relay/channel/jimeng/sign.go b/relay/channel/jimeng/sign.go
index c9db6630..d8b598dc 100644
--- a/relay/channel/jimeng/sign.go
+++ b/relay/channel/jimeng/sign.go
@@ -12,7 +12,7 @@ import (
"io"
"net/http"
"net/url"
- "one-api/common"
+ "one-api/logger"
"sort"
"strings"
"time"
@@ -44,7 +44,7 @@ func SetPayloadHash(c *gin.Context, req any) error {
if err != nil {
return err
}
- common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
+ logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
payloadHash := sha256.Sum256(body)
hexPayloadHash := hex.EncodeToString(payloadHash[:])
c.Set(HexPayloadHashKey, hexPayloadHash)
diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go
index 78f96d6d..d91aceb3 100644
--- a/relay/channel/mokaai/relay-mokaai.go
+++ b/relay/channel/mokaai/relay-mokaai.go
@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -56,7 +57,7 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
@@ -77,6 +78,6 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil
}
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index d4686ce3..066581fa 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -94,7 +94,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -123,7 +123,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- common.IOCopyBytesGracefully(c, resp, doResponseBody)
+ service.IOCopyBytesGracefully(c, resp, doResponseBody)
return usage, nil
}
diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go
index 696c5cb0..80973aa1 100644
--- a/relay/channel/openai/helper.go
+++ b/relay/channel/openai/helper.go
@@ -7,6 +7,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
@@ -50,7 +51,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
- common.LogError(c, "failed to unmarshal stream response: "+err.Error())
+ logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
return err
}
@@ -63,7 +64,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
- common.LogError(c, "failed to marshal gemini response: "+err.Error())
+ logger.LogError(c, "failed to marshal gemini response: "+err.Error())
return err
}
@@ -110,14 +111,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
- common.SysError("error processing stream response: " + err.Error())
+ logger.SysError("error processing stream response: " + err.Error())
}
}
return nil
@@ -146,7 +147,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
@@ -213,7 +214,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return
}
@@ -227,7 +228,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
case relaycommon.RelayFormatGemini:
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return
}
@@ -245,7 +246,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
- common.SysError("error marshalling gemini response: " + err.Error())
+ logger.SysError("error marshalling gemini response: " + err.Error())
return
}
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index b8e72273..447e0f31 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -108,11 +109,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
- common.LogError(c, "invalid response or response body")
+ logger.LogError(c, "invalid response or response body")
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
}
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
model := info.UpstreamModelName
var responseId string
@@ -129,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
if lastStreamData != "" {
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
if err != nil {
- common.SysError("error handling stream format: " + err.Error())
+ logger.SysError("error handling stream format: " + err.Error())
}
}
if len(data) > 0 {
@@ -143,7 +144,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
shouldSendLastResp := true
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
&containStreamUsage, info, &shouldSendLastResp); err != nil {
- common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
+ logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
}
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
@@ -154,7 +155,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
- common.LogError(c, "error processing tokens: "+err.Error())
+ logger.LogError(c, "error processing tokens: "+err.Error())
}
if !containStreamUsage {
@@ -173,7 +174,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
@@ -235,7 +236,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
responseBody = geminiRespStr
}
- common.IOCopyBytesGracefully(c, resp, responseBody)
+ service.IOCopyBytesGracefully(c, resp, responseBody)
return &simpleResponse.Usage, nil
}
@@ -247,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly.
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.TotalTokens = info.PromptTokens
@@ -258,13 +259,13 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
- common.LogError(c, err.Error())
+ logger.LogError(c, err.Error())
}
return usage
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
// count tokens by audio file duration
audioTokens, err := countAudioTokens(c)
@@ -276,7 +277,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
// 写入新的 response body
- common.IOCopyBytesGracefully(c, resp, responseBody)
+ service.IOCopyBytesGracefully(c, resp, responseBody)
usage := &dto.Usage{}
usage.PromptTokens = audioTokens
@@ -386,7 +387,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
- common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+ logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
@@ -459,7 +460,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
- common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+ logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
@@ -474,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
localUsage = &dto.RealtimeUsage{}
// print now usage
}
- common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
- common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
- common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+ logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
+ logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+ logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
@@ -491,7 +492,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
- common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+ logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
@@ -517,7 +518,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
- common.LogError(c, "realtime error: "+err.Error())
+ logger.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
@@ -553,7 +554,7 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
}
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -567,7 +568,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
}
// 写入新的 response body
- common.IOCopyBytesGracefully(c, resp, responseBody)
+ service.IOCopyBytesGracefully(c, resp, responseBody)
// Once we've written to the client, we should not return errors anymore
// because the upstream has already consumed resources and returned content
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
index bae6fcb6..754a6f44 100644
--- a/relay/channel/openai/relay_responses.go
+++ b/relay/channel/openai/relay_responses.go
@@ -6,6 +6,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -16,7 +17,7 @@ import (
)
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
// read response body
var responsesResponse dto.OpenAIResponsesResponse
@@ -33,7 +34,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
// 写入新的 response body
- common.IOCopyBytesGracefully(c, resp, responseBody)
+ service.IOCopyBytesGracefully(c, resp, responseBody)
// compute usage
usage := dto.Usage{}
@@ -54,7 +55,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
- common.LogError(c, "invalid response or response body")
+ logger.LogError(c, "invalid response or response body")
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
}
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index 9b8bce7d..1264b2b4 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -58,15 +59,15 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
+ logger.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -78,7 +79,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -96,7 +97,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
return false
}
})
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
return nil, responseText
}
@@ -105,7 +106,7 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
@@ -133,6 +134,6 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
return &usage, nil
}
diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go
index 2e37ad15..b21faccb 100644
--- a/relay/channel/siliconflow/relay-siliconflow.go
+++ b/relay/channel/siliconflow/relay-siliconflow.go
@@ -4,9 +4,9 @@ import (
"encoding/json"
"io"
"net/http"
- "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -17,7 +17,7 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
@@ -39,6 +39,6 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}
diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go
index 9c04c7ad..1deb33fd 100644
--- a/relay/channel/task/suno/adaptor.go
+++ b/relay/channel/task/suno/adaptor.go
@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -139,7 +140,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
if err != nil {
- common.SysError(fmt.Sprintf("Get Task error: %v", err))
+ logger.SysError(fmt.Sprintf("Get Task error: %v", err))
return nil, err
}
defer req.Body.Close()
diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go
index 78ce6238..d3aeab3f 100644
--- a/relay/channel/tencent/relay-tencent.go
+++ b/relay/channel/tencent/relay-tencent.go
@@ -13,6 +13,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -106,7 +107,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
@@ -117,17 +118,17 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
err = helper.ObjectData(c, response)
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
- common.SysError("error reading stream: " + err.Error())
+ logger.SysError("error reading stream: " + err.Error())
}
helper.Done(c)
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
}
@@ -138,7 +139,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -156,7 +157,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil
}
diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go
index 4d098102..4d4e7b92 100644
--- a/relay/channel/xai/text.go
+++ b/relay/channel/xai/text.go
@@ -6,6 +6,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -47,7 +48,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var xAIResp *dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &xAIResp)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
@@ -63,7 +64,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
}
return true
})
@@ -74,12 +75,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
helper.Done(c)
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
return usage, nil
}
func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer common.CloseResponseBodyGracefully(resp)
+ defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -101,7 +102,7 @@ func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
- common.IOCopyBytesGracefully(c, resp, encodeJson)
+ service.IOCopyBytesGracefully(c, resp, encodeJson)
return xaiResponse.Usage, nil
}
diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go
index 1a426d50..398bb08d 100644
--- a/relay/channel/xunfei/relay-xunfei.go
+++ b/relay/channel/xunfei/relay-xunfei.go
@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/relay/helper"
"one-api/types"
"strings"
@@ -143,7 +144,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -218,20 +219,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
for {
_, msg, err := conn.ReadMessage()
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
+ logger.SysError("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
- common.SysError("error closing websocket connection: " + err.Error())
+ logger.SysError("error closing websocket connection: " + err.Error())
}
break
}
@@ -282,6 +283,6 @@ func getAPIVersion(c *gin.Context, modelName string) string {
return apiVersion
}
apiVersion = "v1.1"
- common.SysLog("api_version not found, using default: " + apiVersion)
+ logger.SysLog("api_version not found, using default: " + apiVersion)
return apiVersion
}
diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go
index 35882ed5..65b662b6 100644
--- a/relay/channel/zhipu/relay-zhipu.go
+++ b/relay/channel/zhipu/relay-zhipu.go
@@ -8,8 +8,10 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
+ "one-api/service"
"one-api/types"
"strings"
"sync"
@@ -38,7 +40,7 @@ func getZhipuToken(apikey string) string {
split := strings.Split(apikey, ".")
if len(split) != 2 {
- common.SysError("invalid zhipu key: " + apikey)
+ logger.SysError("invalid zhipu key: " + apikey)
return ""
}
@@ -186,7 +188,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -195,13 +197,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
var zhipuResponse ZhipuStreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
@@ -212,7 +214,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
return false
}
})
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
return usage, nil
}
@@ -222,7 +224,7 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
diff --git a/relay/relay-mj.go b/relay/chat_handler.go
similarity index 98%
rename from relay/relay-mj.go
rename to relay/chat_handler.go
index e7f316b9..30bce55c 100644
--- a/relay/relay-mj.go
+++ b/relay/chat_handler.go
@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
@@ -214,7 +215,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
+ logger.SysError("error consuming token remain quota: " + err.Error())
}
tokenName := c.GetString("token_name")
@@ -300,7 +301,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
}
- common.IOCopyBytesGracefully(c, nil, respBody)
+ service.IOCopyBytesGracefully(c, nil, respBody)
return nil
}
@@ -521,7 +522,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
+ logger.SysError("error consuming token remain quota: " + err.Error())
}
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
@@ -572,7 +573,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
//无实例账号自动禁用渠道(No available account instance)
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
if err != nil {
- common.SysError("get_channel_null: " + err.Error())
+ logger.SysError("get_channel_null: " + err.Error())
}
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
index b4bf78ff..ddc424b4 100644
--- a/relay/claude_handler.go
+++ b/relay/claude_handler.go
@@ -2,7 +2,6 @@ package relay
import (
"bytes"
- "errors"
"fmt"
"io"
"net/http"
@@ -18,68 +17,26 @@ import (
"github.com/gin-gonic/gin"
)
-func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
- textRequest = &dto.ClaudeRequest{}
- err = c.ShouldBindJSON(textRequest)
- if err != nil {
- return nil, err
- }
- if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
- return nil, errors.New("field messages is required")
- }
- if textRequest.Model == "" {
- return nil, errors.New("field model is required")
- }
- return textRequest, nil
-}
+func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
-func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
- relayInfo := relaycommon.GenRelayInfoClaude(c)
+ textRequest, ok := info.Request.(*dto.ClaudeRequest)
- // get & validate textRequest 获取并验证文本请求
- textRequest, err := getAndValidateClaudeRequest(c)
- if err != nil {
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ if !ok {
+ common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
}
- if textRequest.Stream {
- relayInfo.IsStream = true
- }
-
- err = helper.ModelMappedHelper(c, relayInfo, textRequest)
+ err := helper.ModelMappedHelper(c, info, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
- // count messages token error 计算promptTokens错误
- if err != nil {
- return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
- }
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
- if err != nil {
- return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
- }
-
- // pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-
- if newAPIError != nil {
- return newAPIError
- }
- defer func() {
- if newAPIError != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
if textRequest.MaxTokens == 0 {
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
@@ -104,18 +61,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
textRequest.Temperature = common.GetPointer[float64](1.0)
}
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
- relayInfo.UpstreamModelName = textRequest.Model
+ info.UpstreamModelName = textRequest.Model
}
var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
- convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
+ convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -125,10 +82,10 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
- if len(relayInfo.ParamOverride) > 0 {
+ if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
- for key, value := range relayInfo.ParamOverride {
+ for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -145,14 +102,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
if resp != nil {
httpResp = resp.(*http.Response)
- relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
@@ -161,24 +118,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
- usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
//log.Printf("usage: %v", usage)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
- service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+
+ service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
return nil
}
-
-func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
- var promptTokens int
- var err error
- switch info.RelayMode {
- default:
- promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
- }
- info.PromptTokens = promptTokens
- return promptTokens, err
-}
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 5cd9223b..59be0011 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -1,10 +1,12 @@
package common
import (
+ "errors"
"one-api/common"
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
+ "one-api/types"
"strings"
"time"
@@ -33,17 +35,6 @@ type ClaudeConvertInfo struct {
Done bool
}
-const (
- RelayFormatOpenAI = "openai"
- RelayFormatClaude = "claude"
- RelayFormatGemini = "gemini"
- RelayFormatOpenAIResponses = "openai_responses"
- RelayFormatOpenAIAudio = "openai_audio"
- RelayFormatOpenAIImage = "openai_image"
- RelayFormatRerank = "rerank"
- RelayFormatEmbedding = "embedding"
-)
-
type RerankerInfo struct {
Documents []any
ReturnDocuments bool
@@ -59,61 +50,103 @@ type ResponsesUsageInfo struct {
BuiltInTools map[string]*BuildInToolInfo
}
-type RelayInfo struct {
+type ChannelMeta struct {
ChannelType int
ChannelId int
- ChannelIsMultiKey bool // 是否多密钥
- ChannelMultiKeyIndex int // 多密钥索引
- TokenId int
- TokenKey string
- UserId int
- UsingGroup string // 使用的分组
- UserGroup string // 用户所在分组
- TokenUnlimited bool
- StartTime time.Time
- FirstResponseTime time.Time
- isFirstResponse bool
+ ChannelIsMultiKey bool
+ ChannelMultiKeyIndex int
+ ChannelBaseUrl string
+ ApiType int
+ ApiVersion string
+ ApiKey string
+ Organization string
+ ChannelCreateTime int64
+ ParamOverride map[string]interface{}
+ ChannelSetting dto.ChannelSettings
+ ChannelOtherSettings dto.ChannelOtherSettings
+ UpstreamModelName string
+ IsModelMapped bool
+}
+
+type RelayInfo struct {
+ TokenId int
+ TokenKey string
+ UserId int
+ UsingGroup string // 使用的分组
+ UserGroup string // 用户所在分组
+ TokenUnlimited bool
+ StartTime time.Time
+ FirstResponseTime time.Time
+ isFirstResponse bool
//SendLastReasoningResponse bool
- ApiType int
IsStream bool
IsGeminiBatchEmbedding bool
IsPlayground bool
UsePrice bool
RelayMode int
- UpstreamModelName string
OriginModelName string
//RecodeModelName string
- RequestURLPath string
- ApiVersion string
- PromptTokens int
- ApiKey string
- Organization string
- BaseUrl string
- SupportStreamOptions bool
- ShouldIncludeUsage bool
- DisablePing bool // 是否禁止向下游发送自定义 Ping
- IsModelMapped bool
- ClientWs *websocket.Conn
- TargetWs *websocket.Conn
- InputAudioFormat string
- OutputAudioFormat string
- RealtimeTools []dto.RealTimeTool
- IsFirstRequest bool
- AudioUsage bool
- ReasoningEffort string
- ChannelSetting dto.ChannelSettings
- ChannelOtherSettings dto.ChannelOtherSettings
- ParamOverride map[string]interface{}
- UserSetting dto.UserSetting
- UserEmail string
- UserQuota int
- RelayFormat string
- SendResponseCount int
- ChannelCreateTime int64
+ RequestURLPath string
+ PromptTokens int
+ SupportStreamOptions bool
+ ShouldIncludeUsage bool
+ DisablePing bool // 是否禁止向下游发送自定义 Ping
+ ClientWs *websocket.Conn
+ TargetWs *websocket.Conn
+ InputAudioFormat string
+ OutputAudioFormat string
+ RealtimeTools []dto.RealTimeTool
+ IsFirstRequest bool
+ AudioUsage bool
+ ReasoningEffort string
+ UserSetting dto.UserSetting
+ UserEmail string
+ UserQuota int
+ RelayFormat types.RelayFormat
+ SendResponseCount int
+ FinalPreConsumedQuota int // 最终预消耗的配额
+
+ PriceData types.PriceData
+
+ Request dto.Request
+
ThinkingContentInfo
*ClaudeConvertInfo
*RerankerInfo
*ResponsesUsageInfo
+ *ChannelMeta
+}
+
+func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
+ channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
+ paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
+ apiType, _ := common.ChannelType2APIType(channelType)
+ channelMeta := &ChannelMeta{
+ ChannelType: channelType,
+ ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
+ ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
+ ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
+ ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
+ ApiType: apiType,
+ ApiVersion: c.GetString("api_version"),
+ ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
+ Organization: c.GetString("channel_organization"),
+ ChannelCreateTime: c.GetInt64("channel_create_time"),
+ ParamOverride: paramOverride,
+ UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+ IsModelMapped: false,
+ }
+
+ channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
+ if ok {
+ channelMeta.ChannelSetting = channelSetting
+ }
+
+ channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
+ if ok {
+ channelMeta.ChannelOtherSettings = channelOtherSettings
+ }
+ info.ChannelMeta = channelMeta
}
// 定义支持流式选项的通道类型
@@ -132,7 +165,8 @@ var streamSupportedChannels = map[int]bool{
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
- info := GenRelayInfo(c)
+ info := genBaseRelayInfo(c, nil)
+ info.RelayFormat = types.RelayFormatOpenAIRealtime
info.ClientWs = ws
info.InputAudioFormat = "pcm16"
info.OutputAudioFormat = "pcm16"
@@ -140,9 +174,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
return info
}
-func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
- info := GenRelayInfo(c)
- info.RelayFormat = RelayFormatClaude
+func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatClaude
info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone,
@@ -150,41 +184,41 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
return info
}
-func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
- info := GenRelayInfo(c)
+func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
info.RelayMode = relayconstant.RelayModeRerank
- info.RelayFormat = RelayFormatRerank
+ info.RelayFormat = types.RelayFormatRerank
info.RerankerInfo = &RerankerInfo{
- Documents: req.Documents,
- ReturnDocuments: req.GetReturnDocuments(),
+ Documents: request.Documents,
+ ReturnDocuments: request.GetReturnDocuments(),
}
return info
}
-func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
- info := GenRelayInfo(c)
- info.RelayFormat = RelayFormatOpenAIAudio
+func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatOpenAIAudio
return info
}
-func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
- info := GenRelayInfo(c)
- info.RelayFormat = RelayFormatEmbedding
+func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatEmbedding
return info
}
-func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
- info := GenRelayInfo(c)
+func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
info.RelayMode = relayconstant.RelayModeResponses
- info.RelayFormat = RelayFormatOpenAIResponses
+ info.RelayFormat = types.RelayFormatOpenAIResponses
info.SupportStreamOptions = false
info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo),
}
- if len(req.Tools) > 0 {
- for _, tool := range req.Tools {
+ if len(request.Tools) > 0 {
+ for _, tool := range request.Tools {
toolType := common.Interface2String(tool["type"])
info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
ToolName: toolType,
@@ -200,104 +234,76 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
}
}
}
- info.IsStream = req.Stream
return info
}
-func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
- info := GenRelayInfo(c)
- info.RelayFormat = RelayFormatGemini
+func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatGemini
info.ShouldIncludeUsage = false
+
return info
}
-func GenRelayInfoImage(c *gin.Context) *RelayInfo {
- info := GenRelayInfo(c)
- info.RelayFormat = RelayFormatOpenAIImage
+func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatOpenAIImage
return info
}
-func GenRelayInfo(c *gin.Context) *RelayInfo {
- channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
- channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
- paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
+func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatOpenAI
+ return info
+}
+
+func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
+
+ //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
+ //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
+ //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
- tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
- tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
- userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
- tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
if startTime.IsZero() {
startTime = time.Now()
}
+
// firstResponseTime = time.Now() - 1 second
- apiType, _ := common.ChannelType2APIType(channelType)
-
info := &RelayInfo{
- UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
- UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
- isFirstResponse: true,
- RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
- BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
- RequestURLPath: c.Request.URL.String(),
- ChannelType: channelType,
- ChannelId: channelId,
- TokenId: tokenId,
- TokenKey: tokenKey,
- UserId: userId,
- UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
- UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
- TokenUnlimited: tokenUnlimited,
+ Request: request,
+
+ UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
+ UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
+ UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
+ UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
+ UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
+
+ OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+ PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
+
+ TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
+ TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
+ TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
+
+ isFirstResponse: true,
+ RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
+ RequestURLPath: c.Request.URL.String(),
+ IsStream: request.IsStream(c),
+
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
- OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
- UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
- //RecodeModelName: c.GetString("original_model"),
- IsModelMapped: false,
- ApiType: apiType,
- ApiVersion: c.GetString("api_version"),
- ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
- Organization: c.GetString("channel_organization"),
-
- ChannelCreateTime: c.GetInt64("channel_create_time"),
- ParamOverride: paramOverride,
- RelayFormat: RelayFormatOpenAI,
ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
},
-
- ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
- ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
}
+
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
info.RequestURLPath = "/v1" + info.RequestURLPath
}
- if info.BaseUrl == "" {
- info.BaseUrl = constant.ChannelBaseURLs[channelType]
- }
- if info.ChannelType == constant.ChannelTypeAzure {
- info.ApiVersion = GetAPIVersion(c)
- }
- if info.ChannelType == constant.ChannelTypeVertexAi {
- info.ApiVersion = c.GetString("region")
- }
- if streamSupportedChannels[info.ChannelType] {
- info.SupportStreamOptions = true
- }
-
- channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
- if ok {
- info.ChannelSetting = channelSetting
- }
-
- channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
- if ok {
- info.ChannelOtherSettings = channelOtherSettings
- }
userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
if ok {
@@ -307,12 +313,39 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
return info
}
-func (info *RelayInfo) SetPromptTokens(promptTokens int) {
- info.PromptTokens = promptTokens
+func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
+ switch relayFormat {
+ case types.RelayFormatOpenAI:
+ return GenRelayInfoOpenAI(c, request), nil
+ case types.RelayFormatOpenAIAudio:
+ return GenRelayInfoOpenAIAudio(c, request), nil
+ case types.RelayFormatOpenAIImage:
+ return GenRelayInfoImage(c, request), nil
+ case types.RelayFormatOpenAIRealtime:
+ return GenRelayInfoWs(c, ws), nil
+ case types.RelayFormatClaude:
+ return GenRelayInfoClaude(c, request), nil
+ case types.RelayFormatRerank:
+ if request, ok := request.(*dto.RerankRequest); ok {
+ return GenRelayInfoRerank(c, request), nil
+ }
+ return nil, errors.New("request is not a RerankRequest")
+ case types.RelayFormatGemini:
+ return GenRelayInfoGemini(c, request), nil
+ case types.RelayFormatEmbedding:
+ return GenRelayInfoEmbedding(c, request), nil
+ case types.RelayFormatOpenAIResponses:
+ if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
+ return GenRelayInfoResponses(c, request), nil
+ }
+ return nil, errors.New("request is not a OpenAIResponsesRequest")
+ default:
+ return nil, errors.New("invalid relay format")
+ }
}
-func (info *RelayInfo) SetIsStream(isStream bool) {
- info.IsStream = isStream
+func (info *RelayInfo) SetPromptTokens(promptTokens int) {
+ info.PromptTokens = promptTokens
}
func (info *RelayInfo) SetFirstResponseTime() {
diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go
index 57df5fe3..05dbfa6d 100644
--- a/relay/common_handler/rerank.go
+++ b/relay/common_handler/rerank.go
@@ -8,6 +8,7 @@ import (
"one-api/dto"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
+ "one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -18,7 +19,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
- common.CloseResponseBodyGracefully(resp)
+ service.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println("reranker response body: ", string(responseBody))
}
diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go
index fef8d2c9..f7906cf9 100644
--- a/relay/embedding_handler.go
+++ b/relay/embedding_handler.go
@@ -8,7 +8,6 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
@@ -16,69 +15,27 @@ import (
"github.com/gin-gonic/gin"
)
-func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
- token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
- return token
-}
+func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
-func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
- if embeddingRequest.Input == nil {
- return fmt.Errorf("input is empty")
- }
- if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
- embeddingRequest.Model = "omni-moderation-latest"
- }
- if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
- embeddingRequest.Model = c.Param("model")
- }
- return nil
-}
+ info.InitChannelMeta(c)
-func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
- relayInfo := relaycommon.GenRelayInfoEmbedding(c)
-
- var embeddingRequest *dto.EmbeddingRequest
- err := common.UnmarshalBodyReusable(c, &embeddingRequest)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
+ if !ok {
+ common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
}
- err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
+ err := helper.ModelMappedHelper(c, info, embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- promptToken := getEmbeddingPromptToken(*embeddingRequest)
- relayInfo.PromptTokens = promptToken
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
- if err != nil {
- return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
- }
- // pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if newAPIError != nil {
- return newAPIError
- }
- defer func() {
- if newAPIError != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
- convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
+ convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -88,7 +45,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -104,12 +61,12 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
- usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}
diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go
index e0581156..3ebe0884 100644
--- a/relay/gemini_handler.go
+++ b/relay/gemini_handler.go
@@ -2,17 +2,16 @@ package relay
import (
"bytes"
- "errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/relay/channel/gemini"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
- "one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
@@ -20,64 +19,6 @@ import (
"github.com/gin-gonic/gin"
)
-func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
- request := &dto.GeminiChatRequest{}
- err := common.UnmarshalBodyReusable(c, request)
- if err != nil {
- return nil, err
- }
- if len(request.Contents) == 0 {
- return nil, errors.New("contents is required")
- }
- return request, nil
-}
-
-// 流模式
-// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
-func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
- if c.Query("alt") == "sse" {
- relayInfo.IsStream = true
- }
-
- // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
- // relayInfo.IsStream = true
- // }
-}
-
-func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
- var inputTexts []string
- for _, content := range textRequest.Contents {
- for _, part := range content.Parts {
- if part.Text != "" {
- inputTexts = append(inputTexts, part.Text)
- }
- }
- }
- if len(inputTexts) == 0 {
- return nil, nil
- }
-
- sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
- return sensitiveWords, err
-}
-
-func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
- // 计算输入 token 数量
- var inputTexts []string
- for _, content := range req.Contents {
- for _, part := range content.Parts {
- if part.Text != "" {
- inputTexts = append(inputTexts, part.Text)
- }
- }
- }
-
- inputText := strings.Join(inputTexts, "\n")
- inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
- info.PromptTokens = inputTokens
- return inputTokens
-}
-
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
@@ -109,97 +50,61 @@ func trimModelThinking(modelName string) string {
return modelName
}
-func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
- req, err := getAndValidateGeminiRequest(c)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
+func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
- relayInfo := relaycommon.GenRelayInfoGemini(c)
-
- // 检查 Gemini 流式模式
- checkGeminiStreamMode(c, relayInfo)
-
- if setting.ShouldCheckPromptSensitive() {
- sensitiveWords, err := checkGeminiInputSensitive(req)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
- return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
- }
+ request, ok := info.Request.(*dto.GeminiChatRequest)
+ if !ok {
+ common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request))
}
// model mapped 模型映射
- err = helper.ModelMappedHelper(c, relayInfo, req)
+ err := helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- if value, exists := c.Get("prompt_tokens"); exists {
- promptTokens := value.(int)
- relayInfo.SetPromptTokens(promptTokens)
- } else {
- promptTokens := getGeminiInputTokens(req, relayInfo)
- c.Set("prompt_tokens", promptTokens)
- }
-
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- if isNoThinkingRequest(req) {
+ if isNoThinkingRequest(request) {
// check is thinking
- if !strings.Contains(relayInfo.OriginModelName, "-nothinking") {
+ if !strings.Contains(info.OriginModelName, "-nothinking") {
// try to get no thinking model price
- noThinkingModelName := relayInfo.OriginModelName + "-nothinking"
+ noThinkingModelName := info.OriginModelName + "-nothinking"
containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
if containPrice {
- relayInfo.OriginModelName = noThinkingModelName
- relayInfo.UpstreamModelName = noThinkingModelName
+ info.OriginModelName = noThinkingModelName
+ info.UpstreamModelName = noThinkingModelName
}
}
}
- if req.GenerationConfig.ThinkingConfig == nil {
- gemini.ThinkingAdaptor(req, relayInfo)
+ if request.GenerationConfig.ThinkingConfig == nil {
+ gemini.ThinkingAdaptor(request, info)
}
}
- priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
- if err != nil {
- return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
- }
-
- // pre consume quota
- preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if newAPIError != nil {
- return newAPIError
- }
- defer func() {
- if newAPIError != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
// Clean up empty system instruction
- if req.SystemInstructions != nil {
+ if request.SystemInstructions != nil {
hasContent := false
- for _, part := range req.SystemInstructions.Parts {
+ for _, part := range request.SystemInstructions.Parts {
if part.Text != "" {
hasContent = true
break
}
}
if !hasContent {
- req.SystemInstructions = nil
+ request.SystemInstructions = nil
}
}
var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
@@ -207,7 +112,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
requestBody = bytes.NewReader(body)
} else {
// 使用 ConvertGeminiRequest 转换请求格式
- convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
+ convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -217,10 +122,10 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
- if len(relayInfo.ParamOverride) > 0 {
+ if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
- for key, value := range relayInfo.ParamOverride {
+ for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -229,15 +134,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
- if common.DebugEnabled {
- println("Gemini request body: %s", string(jsonData))
- }
+ logger.LogDebug(c, "Gemini request body: "+string(jsonData))
+
requestBody = bytes.NewReader(jsonData)
}
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
- common.LogError(c, "Do gemini request failed: "+err.Error())
+ logger.LogError(c, "Do gemini request failed: "+err.Error())
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -246,7 +150,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
- relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
@@ -255,23 +159,22 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
- usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+ usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
if openaiErr != nil {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}
-func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
- relayInfo := relaycommon.GenRelayInfoGemini(c)
+func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
- relayInfo.IsGeminiBatchEmbedding = isBatch
+ info.IsGeminiBatchEmbedding = isBatch
- var promptTokens int
var req any
var err error
var inputTexts []string
@@ -303,35 +206,17 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
}
- promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
- relayInfo.SetPromptTokens(promptTokens)
- c.Set("prompt_tokens", promptTokens)
- err = helper.ModelMappedHelper(c, relayInfo, req)
+ err = helper.ModelMappedHelper(c, info, req)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
- if err != nil {
- return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
- }
-
- preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if newAPIError != nil {
- return newAPIError
- }
- defer func() {
- if newAPIError != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
var requestBody io.Reader
jsonData, err := common.Marshal(req)
@@ -340,10 +225,10 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
- if len(relayInfo.ParamOverride) > 0 {
+ if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
- for key, value := range relayInfo.ParamOverride {
+ for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -353,9 +238,9 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody = bytes.NewReader(jsonData)
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
- common.LogError(c, "Do gemini request failed: "+err.Error())
+ logger.LogError(c, "Do gemini request failed: "+err.Error())
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -370,12 +255,12 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
- usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+ usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
if openaiErr != nil {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}
diff --git a/relay/helper/common.go b/relay/helper/common.go
index c8edb798..5075314d 100644
--- a/relay/helper/common.go
+++ b/relay/helper/common.go
@@ -7,6 +7,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -100,7 +101,7 @@ func Done(c *gin.Context) {
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
if ws == nil {
- common.LogError(c, "websocket connection is nil")
+ logger.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
@@ -113,7 +114,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
return fmt.Errorf("error marshalling object: %w", err)
}
if ws == nil {
- common.LogError(c, "websocket connection is nil")
+ logger.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go
index c1735149..e894e228 100644
--- a/relay/helper/model_mapped.go
+++ b/relay/helper/model_mapped.go
@@ -4,9 +4,10 @@ import (
"encoding/json"
"errors"
"fmt"
- common2 "one-api/common"
"one-api/dto"
+ common2 "one-api/logger"
"one-api/relay/common"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -54,29 +55,29 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro
}
if request != nil {
switch info.RelayFormat {
- case common.RelayFormatGemini:
+ case types.RelayFormatGemini:
// Gemini 模型映射
- case common.RelayFormatClaude:
+ case types.RelayFormatClaude:
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
claudeRequest.Model = info.UpstreamModelName
}
- case common.RelayFormatOpenAIResponses:
+ case types.RelayFormatOpenAIResponses:
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
openAIResponsesRequest.Model = info.UpstreamModelName
}
- case common.RelayFormatOpenAIAudio:
+ case types.RelayFormatOpenAIAudio:
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
openAIAudioRequest.Model = info.UpstreamModelName
}
- case common.RelayFormatOpenAIImage:
+ case types.RelayFormatOpenAIImage:
if imageRequest, ok := request.(*dto.ImageRequest); ok {
imageRequest.Model = info.UpstreamModelName
}
- case common.RelayFormatRerank:
+ case types.RelayFormatRerank:
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
rerankRequest.Model = info.UpstreamModelName
}
- case common.RelayFormatEmbedding:
+ case types.RelayFormatEmbedding:
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
embeddingRequest.Model = info.UpstreamModelName
}
diff --git a/relay/helper/price.go b/relay/helper/price.go
index e80578e5..89fc3b66 100644
--- a/relay/helper/price.go
+++ b/relay/helper/price.go
@@ -5,35 +5,14 @@ import (
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting/ratio_setting"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
-type GroupRatioInfo struct {
- GroupRatio float64
- GroupSpecialRatio float64
- HasSpecialRatio bool
-}
-
-type PriceData struct {
- ModelPrice float64
- ModelRatio float64
- CompletionRatio float64
- CacheRatio float64
- CacheCreationRatio float64
- ImageRatio float64
- UsePrice bool
- ShouldPreConsumedQuota int
- GroupRatioInfo GroupRatioInfo
-}
-
-func (p PriceData) ToSetting() string {
- return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
-}
-
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
-func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
- groupRatioInfo := GroupRatioInfo{
+func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo {
+ groupRatioInfo := types.GroupRatioInfo{
GroupRatio: 1.0, // default ratio
GroupSpecialRatio: -1,
}
@@ -62,7 +41,7 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR
return groupRatioInfo
}
-func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
+func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) {
modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
groupRatioInfo := HandleGroupRatio(c, info)
@@ -75,8 +54,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var cacheCreationRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
- if maxTokens != 0 {
- preConsumedTokens = promptTokens + maxTokens
+ if meta.MaxTokens != 0 {
+ preConsumedTokens = promptTokens + meta.MaxTokens
}
var success bool
var matchName string
@@ -87,7 +66,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
acceptUnsetRatio = true
}
if !acceptUnsetRatio {
- return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
+ return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
}
}
completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
@@ -97,10 +76,13 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
+ if meta.ImagePriceRatio != 0 {
+ modelPrice = modelPrice * meta.ImagePriceRatio
+ }
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
- priceData := PriceData{
+ priceData := types.PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
@@ -115,38 +97,32 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
if common.DebugEnabled {
println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
}
-
+ info.PriceData = priceData
return priceData, nil
}
-type PerCallPriceData struct {
- ModelPrice float64
- Quota int
- GroupRatioInfo GroupRatioInfo
-}
-
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
-func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData {
- groupRatioInfo := HandleGroupRatio(c, info)
-
- modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
- // 如果没有配置价格,则使用默认价格
- if !success {
- defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
- if !ok {
- modelPrice = 0.1
- } else {
- modelPrice = defaultPrice
- }
- }
- quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
- priceData := PerCallPriceData{
- ModelPrice: modelPrice,
- Quota: quota,
- GroupRatioInfo: groupRatioInfo,
- }
- return priceData
-}
+//func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
+// groupRatioInfo := HandleGroupRatio(c, info)
+//
+// modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
+// // 如果没有配置价格,则使用默认价格
+// if !success {
+// defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
+// if !ok {
+// modelPrice = 0.1
+// } else {
+// modelPrice = defaultPrice
+// }
+// }
+// quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
+// priceData := types.PerCallPriceData{
+// ModelPrice: modelPrice,
+// Quota: quota,
+// GroupRatioInfo: groupRatioInfo,
+// }
+// return priceData
+//}
func ContainPriceOrRatio(modelName string) bool {
_, ok := ratio_setting.GetModelPrice(modelName, false)
diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go
index a5706f95..725d178c 100644
--- a/relay/helper/stream_scanner.go
+++ b/relay/helper/stream_scanner.go
@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/constant"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/setting/operation_setting"
"strings"
@@ -87,7 +88,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
select {
case <-done:
case <-time.After(5 * time.Second):
- common.LogError(c, "timeout waiting for goroutines to exit")
+ logger.LogError(c, "timeout waiting for goroutines to exit")
}
close(stopChan)
@@ -109,7 +110,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer func() {
wg.Done()
if r := recover(); r != nil {
- common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
+ logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
common.SafeSendBool(stopChan, true)
}
if common.DebugEnabled {
@@ -136,14 +137,14 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
select {
case err := <-done:
if err != nil {
- common.LogError(c, "ping data error: "+err.Error())
+ logger.LogError(c, "ping data error: "+err.Error())
return
}
if common.DebugEnabled {
println("ping data sent")
}
case <-time.After(10 * time.Second):
- common.LogError(c, "ping data send timeout")
+ logger.LogError(c, "ping data send timeout")
return
case <-ctx.Done():
return
@@ -158,7 +159,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
// 监听客户端断开连接
return
case <-pingTimeout.C:
- common.LogError(c, "ping goroutine max duration reached")
+ logger.LogError(c, "ping goroutine max duration reached")
return
}
}
@@ -171,7 +172,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer func() {
wg.Done()
if r := recover(); r != nil {
- common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
+ logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
}
common.SafeSendBool(stopChan, true)
if common.DebugEnabled {
@@ -223,7 +224,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
return
}
case <-time.After(10 * time.Second):
- common.LogError(c, "data handler timeout")
+ logger.LogError(c, "data handler timeout")
return
case <-ctx.Done():
return
@@ -241,7 +242,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
if err := scanner.Err(); err != nil {
if err != io.EOF {
- common.LogError(c, "scanner error: "+err.Error())
+ logger.LogError(c, "scanner error: "+err.Error())
}
}
})
@@ -250,12 +251,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
select {
case <-ticker.C:
// 超时处理逻辑
- common.LogError(c, "streaming timeout")
+ logger.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
- common.LogInfo(c, "streaming finished")
+ logger.LogInfo(c, "streaming finished")
case <-c.Request.Context().Done():
// 客户端断开连接
- common.LogInfo(c, "client disconnected")
+ logger.LogInfo(c, "client disconnected")
}
}
diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go
new file mode 100644
index 00000000..0bc51774
--- /dev/null
+++ b/relay/helper/valid_request.go
@@ -0,0 +1,301 @@
+package helper
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/logger"
+ relayconstant "one-api/relay/constant"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) {
+ relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
+
+ switch format {
+ case types.RelayFormatOpenAI:
+ request, err = GetAndValidateTextRequest(c, relayMode)
+ case types.RelayFormatGemini:
+ request, err = GetAndValidateGeminiRequest(c)
+ case types.RelayFormatClaude:
+ request, err = GetAndValidateClaudeRequest(c)
+ case types.RelayFormatOpenAIResponses:
+ request, err = GetAndValidateResponsesRequest(c)
+
+ case types.RelayFormatOpenAIImage:
+ request, err = GetAndValidOpenAIImageRequest(c, relayMode)
+ case types.RelayFormatEmbedding:
+ request, err = GetAndValidateEmbeddingRequest(c, relayMode)
+ case types.RelayFormatRerank:
+ request, err = GetAndValidateRerankRequest(c)
+ case types.RelayFormatOpenAIAudio:
+ request, err = GetAndValidAudioRequest(c, relayMode)
+ case types.RelayFormatOpenAIRealtime:
+ // nothing to do, no request body
+ default:
+ return nil, fmt.Errorf("unsupported relay format: %s", format)
+ }
+ return request, err
+}
+
+func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) {
+ audioRequest := &dto.AudioRequest{}
+ err := common.UnmarshalBodyReusable(c, audioRequest)
+ if err != nil {
+ return nil, err
+ }
+ switch relayMode {
+ case relayconstant.RelayModeAudioSpeech:
+ if audioRequest.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ default:
+ err = c.Request.ParseForm()
+ if err != nil {
+ return nil, err
+ }
+ formData := c.Request.PostForm
+ if audioRequest.Model == "" {
+ audioRequest.Model = formData.Get("model")
+ }
+
+ if audioRequest.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ audioRequest.ResponseFormat = formData.Get("response_format")
+ if audioRequest.ResponseFormat == "" {
+ audioRequest.ResponseFormat = "json"
+ }
+ }
+ return audioRequest, nil
+}
+
+func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) {
+ var rerankRequest *dto.RerankRequest
+ err := common.UnmarshalBodyReusable(c, &rerankRequest)
+ if err != nil {
+ logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+ return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ if rerankRequest.Query == "" {
+ return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+ if len(rerankRequest.Documents) == 0 {
+ return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+ return rerankRequest, nil
+}
+
+func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) {
+ var embeddingRequest *dto.EmbeddingRequest
+ err := common.UnmarshalBodyReusable(c, &embeddingRequest)
+ if err != nil {
+ logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+ return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ if embeddingRequest.Input == nil {
+ return nil, fmt.Errorf("input is empty")
+ }
+ if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
+ embeddingRequest.Model = "omni-moderation-latest"
+ }
+ if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
+ embeddingRequest.Model = c.Param("model")
+ }
+ return embeddingRequest, nil
+}
+
+func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
+ request := &dto.OpenAIResponsesRequest{}
+ err := common.UnmarshalBodyReusable(c, request)
+ if err != nil {
+ return nil, err
+ }
+ if request.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ if request.Input == nil {
+ return nil, errors.New("input is required")
+ }
+ return request, nil
+}
+
+func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) {
+ imageRequest := &dto.ImageRequest{}
+
+ switch relayMode {
+ case relayconstant.RelayModeImagesEdits:
+ _, err := c.MultipartForm()
+ if err != nil {
+ return nil, err
+ }
+ formData := c.Request.PostForm
+ imageRequest.Prompt = formData.Get("prompt")
+ imageRequest.Model = formData.Get("model")
+ imageRequest.N = uint(common.String2Int(formData.Get("n")))
+ imageRequest.Quality = formData.Get("quality")
+ imageRequest.Size = formData.Get("size")
+
+ if imageRequest.Model == "gpt-image-1" {
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "standard"
+ }
+ }
+ if imageRequest.N == 0 {
+ imageRequest.N = 1
+ }
+
+ watermark := formData.Has("watermark")
+ if watermark {
+ imageRequest.Watermark = &watermark
+ }
+ default:
+ err := common.UnmarshalBodyReusable(c, imageRequest)
+ if err != nil {
+ return nil, err
+ }
+
+ if imageRequest.Model == "" {
+ imageRequest.Model = "dall-e-3"
+ }
+
+ if strings.Contains(imageRequest.Size, "×") {
+ return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
+ }
+
+ // Not "256x256", "512x512", or "1024x1024"
+ if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
+ if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
+ return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
+ }
+ if imageRequest.Size == "" {
+ imageRequest.Size = "1024x1024"
+ }
+ } else if imageRequest.Model == "dall-e-3" {
+ if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
+ return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
+ }
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "standard"
+ }
+ if imageRequest.Size == "" {
+ imageRequest.Size = "1024x1024"
+ }
+ } else if imageRequest.Model == "gpt-image-1" {
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "auto"
+ }
+ }
+
+ if imageRequest.Prompt == "" {
+ return nil, errors.New("prompt is required")
+ }
+
+ if imageRequest.N == 0 {
+ imageRequest.N = 1
+ }
+ }
+
+ return imageRequest, nil
+}
+
+func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
+ textRequest = &dto.ClaudeRequest{}
+ err = c.ShouldBindJSON(textRequest)
+ if err != nil {
+ return nil, err
+ }
+ if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
+ return nil, errors.New("field messages is required")
+ }
+ if textRequest.Model == "" {
+ return nil, errors.New("field model is required")
+ }
+
+ //if textRequest.Stream {
+ // relayInfo.IsStream = true
+ //}
+
+ return textRequest, nil
+}
+
+func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) {
+ textRequest := &dto.GeneralOpenAIRequest{}
+ err := common.UnmarshalBodyReusable(c, textRequest)
+ if err != nil {
+ return nil, err
+ }
+
+ if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
+ textRequest.Model = "text-moderation-latest"
+ }
+ if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
+ textRequest.Model = c.Param("model")
+ }
+
+ if textRequest.MaxTokens > math.MaxInt32/2 {
+ return nil, errors.New("max_tokens is invalid")
+ }
+ if textRequest.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ if textRequest.WebSearchOptions != nil {
+ if textRequest.WebSearchOptions.SearchContextSize != "" {
+ validSizes := map[string]bool{
+ "high": true,
+ "medium": true,
+ "low": true,
+ }
+ if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
+ return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
+ }
+ } else {
+ textRequest.WebSearchOptions.SearchContextSize = "medium"
+ }
+ }
+ switch relayMode {
+ case relayconstant.RelayModeCompletions:
+ if textRequest.Prompt == "" {
+ return nil, errors.New("field prompt is required")
+ }
+ case relayconstant.RelayModeChatCompletions:
+ if len(textRequest.Messages) == 0 {
+ return nil, errors.New("field messages is required")
+ }
+ case relayconstant.RelayModeEmbeddings:
+ case relayconstant.RelayModeModerations:
+ if textRequest.Input == nil || textRequest.Input == "" {
+ return nil, errors.New("field input is required")
+ }
+ case relayconstant.RelayModeEdits:
+ if textRequest.Instruction == "" {
+ return nil, errors.New("field instruction is required")
+ }
+ }
+ return textRequest, nil
+}
+
+func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
+
+ request := &dto.GeminiChatRequest{}
+ err := common.UnmarshalBodyReusable(c, request)
+ if err != nil {
+ return nil, err
+ }
+ if len(request.Contents) == 0 {
+ return nil, errors.New("contents is required")
+ }
+
+ //if c.Query("alt") == "sse" {
+ // relayInfo.IsStream = true
+ //}
+
+ return request, nil
+}
diff --git a/relay/image_handler.go b/relay/image_handler.go
index f0b69699..008a979d 100644
--- a/relay/image_handler.go
+++ b/relay/image_handler.go
@@ -3,19 +3,15 @@ package relay
import (
"bytes"
"encoding/json"
- "errors"
"fmt"
"io"
"net/http"
"one-api/common"
- "one-api/constant"
"one-api/dto"
- "one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
- "one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
@@ -23,183 +19,41 @@ import (
"github.com/gin-gonic/gin"
)
-func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
- imageRequest := &dto.ImageRequest{}
+func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- switch info.RelayMode {
- case relayconstant.RelayModeImagesEdits:
- _, err := c.MultipartForm()
- if err != nil {
- return nil, err
- }
- formData := c.Request.PostForm
- imageRequest.Prompt = formData.Get("prompt")
- imageRequest.Model = formData.Get("model")
- imageRequest.N = common.String2Int(formData.Get("n"))
- imageRequest.Quality = formData.Get("quality")
- imageRequest.Size = formData.Get("size")
+ info.InitChannelMeta(c)
- if imageRequest.Model == "gpt-image-1" {
- if imageRequest.Quality == "" {
- imageRequest.Quality = "standard"
- }
- }
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
+ imageRequest, ok := info.Request.(*dto.ImageRequest)
- if info.ApiType == constant.APITypeVolcEngine {
- watermark := formData.Has("watermark")
- imageRequest.Watermark = &watermark
- }
- default:
- err := common.UnmarshalBodyReusable(c, imageRequest)
- if err != nil {
- return nil, err
- }
-
- if imageRequest.Model == "" {
- imageRequest.Model = "dall-e-3"
- }
-
- if strings.Contains(imageRequest.Size, "×") {
- return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
- }
-
- // Not "256x256", "512x512", or "1024x1024"
- if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
- if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
- return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
- }
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
- } else if imageRequest.Model == "dall-e-3" {
- if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
- return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
- }
- if imageRequest.Quality == "" {
- imageRequest.Quality = "standard"
- }
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
- } else if imageRequest.Model == "gpt-image-1" {
- if imageRequest.Quality == "" {
- imageRequest.Quality = "auto"
- }
- }
-
- if imageRequest.Prompt == "" {
- return nil, errors.New("prompt is required")
- }
-
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
+ if !ok {
+ common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request))
}
- if setting.ShouldCheckPromptSensitive() {
- words, err := service.CheckSensitiveInput(imageRequest.Prompt)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
- return nil, err
- }
- }
- return imageRequest, nil
-}
-
-func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
- relayInfo := relaycommon.GenRelayInfoImage(c)
-
- imageRequest, err := getAndValidImageRequest(c, relayInfo)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
+ err := helper.ModelMappedHelper(c, info, imageRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
- if err != nil {
- return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
- }
- var preConsumedQuota int
- var quota int
- var userQuota int
- if !priceData.UsePrice {
- // modelRatio 16 = modelPrice $0.04
- // per 1 modelRatio = $0.04 / 16
- // priceData.ModelPrice = 0.0025 * priceData.ModelRatio
- preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if newAPIError != nil {
- return newAPIError
- }
- defer func() {
- if newAPIError != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- } else {
- sizeRatio := 1.0
- qualityRatio := 1.0
-
- if strings.HasPrefix(imageRequest.Model, "dall-e") {
- // Size
- if imageRequest.Size == "256x256" {
- sizeRatio = 0.4
- } else if imageRequest.Size == "512x512" {
- sizeRatio = 0.45
- } else if imageRequest.Size == "1024x1024" {
- sizeRatio = 1
- } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
- sizeRatio = 2
- }
-
- if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
- qualityRatio = 2.0
- if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
- qualityRatio = 1.5
- }
- }
- }
-
- // reset model price
- priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
- quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
- userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
- if err != nil {
- return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
- }
- if userQuota-quota < 0 {
- return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry())
- }
- }
-
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
- convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
+ convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
- if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
+ if info.RelayMode == relayconstant.RelayModeImagesEdits {
requestBody = convertedRequest.(io.Reader)
} else {
jsonData, err := json.Marshal(convertedRequest)
@@ -208,10 +62,10 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
- if len(relayInfo.ParamOverride) > 0 {
+ if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
- for key, value := range relayInfo.ParamOverride {
+ for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -229,14 +83,14 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
statusCodeMappingStr := c.GetString("status_code_mapping")
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
- relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
@@ -245,7 +99,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
- usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
@@ -253,17 +107,23 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
if usage.(*dto.Usage).TotalTokens == 0 {
- usage.(*dto.Usage).TotalTokens = imageRequest.N
+ usage.(*dto.Usage).TotalTokens = int(imageRequest.N)
}
if usage.(*dto.Usage).PromptTokens == 0 {
- usage.(*dto.Usage).PromptTokens = imageRequest.N
+ usage.(*dto.Usage).PromptTokens = int(imageRequest.N)
}
+
quality := "standard"
if imageRequest.Quality == "hd" {
quality = "hd"
}
- logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent)
+ var logContent string
+
+ if len(imageRequest.Size) > 0 {
+ logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
+ }
+
+ postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
return nil
}
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 50d574f3..de750e76 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -2,172 +2,56 @@ package relay
import (
"bytes"
- "errors"
"fmt"
"io"
- "math"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
- "one-api/setting"
"one-api/setting/model_setting"
"one-api/setting/operation_setting"
"one-api/types"
"strings"
"time"
- "github.com/bytedance/gopkg/util/gopool"
"github.com/shopspring/decimal"
"github.com/gin-gonic/gin"
)
-func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
- textRequest := &dto.GeneralOpenAIRequest{}
- err := common.UnmarshalBodyReusable(c, textRequest)
- if err != nil {
- return nil, err
- }
- if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
- textRequest.Model = "text-moderation-latest"
- }
- if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
- textRequest.Model = c.Param("model")
- }
+func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- if textRequest.MaxTokens > math.MaxInt32/2 {
- return nil, errors.New("max_tokens is invalid")
- }
- if textRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- if textRequest.WebSearchOptions != nil {
- if textRequest.WebSearchOptions.SearchContextSize != "" {
- validSizes := map[string]bool{
- "high": true,
- "medium": true,
- "low": true,
- }
- if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
- return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
- }
- } else {
- textRequest.WebSearchOptions.SearchContextSize = "medium"
- }
- }
- switch relayInfo.RelayMode {
- case relayconstant.RelayModeCompletions:
- if textRequest.Prompt == "" {
- return nil, errors.New("field prompt is required")
- }
- case relayconstant.RelayModeChatCompletions:
- if len(textRequest.Messages) == 0 {
- return nil, errors.New("field messages is required")
- }
- case relayconstant.RelayModeEmbeddings:
- case relayconstant.RelayModeModerations:
- if textRequest.Input == nil || textRequest.Input == "" {
- return nil, errors.New("field input is required")
- }
- case relayconstant.RelayModeEdits:
- if textRequest.Instruction == "" {
- return nil, errors.New("field instruction is required")
- }
- }
- relayInfo.IsStream = textRequest.Stream
- return textRequest, nil
-}
+ info.InitChannelMeta(c)
-func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+ textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest)
- relayInfo := relaycommon.GenRelayInfo(c)
-
- // get & validate textRequest 获取并验证文本请求
- textRequest, err := getAndValidateTextRequest(c, relayInfo)
- if err != nil {
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ if !ok {
+ //return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request)
}
if textRequest.WebSearchOptions != nil {
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
}
- if setting.ShouldCheckPromptSensitive() {
- words, err := checkRequestSensitive(textRequest, relayInfo)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
- return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
- }
- }
-
- err = helper.ModelMappedHelper(c, relayInfo, textRequest)
+ err := helper.ModelMappedHelper(c, info, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- // 获取 promptTokens,如果上下文中已经存在,则直接使用
- var promptTokens int
- if value, exists := c.Get("prompt_tokens"); exists {
- promptTokens = value.(int)
- relayInfo.PromptTokens = promptTokens
- } else {
- promptTokens, err = getPromptTokens(textRequest, relayInfo)
- // count messages token error 计算promptTokens错误
- if err != nil {
- return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
- }
- c.Set("prompt_tokens", promptTokens)
- }
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
- if err != nil {
- return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
- }
-
- // pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if newApiErr != nil {
- return newApiErr
- }
- defer func() {
- if newApiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
- includeUsage := true
- // 判断用户是否需要返回使用情况
- if textRequest.StreamOptions != nil {
- includeUsage = textRequest.StreamOptions.IncludeUsage
- }
-
- // 如果不支持StreamOptions,将StreamOptions设置为nil
- if !relayInfo.SupportStreamOptions || !textRequest.Stream {
- textRequest.StreamOptions = nil
- } else {
- // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
- if constant.ForceStreamOption {
- textRequest.StreamOptions = &dto.StreamOptions{
- IncludeUsage: true,
- }
- }
- }
-
- relayInfo.ShouldIncludeUsage = includeUsage
-
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
@@ -177,12 +61,12 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody = bytes.NewBuffer(body)
} else {
- convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
+ convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
- if relayInfo.ChannelSetting.SystemPrompt != "" {
+ if info.ChannelSetting.SystemPrompt != "" {
// 如果有系统提示,则将其添加到请求中
request := convertedRequest.(*dto.GeneralOpenAIRequest)
containSystemPrompt := false
@@ -196,22 +80,22 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
// 如果没有系统提示,则添加系统提示
systemMessage := dto.Message{
Role: request.GetSystemRoleName(),
- Content: relayInfo.ChannelSetting.SystemPrompt,
+ Content: info.ChannelSetting.SystemPrompt,
}
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
- } else if relayInfo.ChannelSetting.SystemPromptOverride {
+ } else if info.ChannelSetting.SystemPromptOverride {
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
// 如果有系统提示,且允许覆盖,则拼接到前面
for i, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
if message.IsStringContent() {
- request.Messages[i].SetStringContent(relayInfo.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
+ request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
} else {
contents := message.ParseContent()
contents = append([]dto.MediaContent{
{
Type: dto.ContentTypeText,
- Text: relayInfo.ChannelSetting.SystemPrompt,
+ Text: info.ChannelSetting.SystemPrompt,
},
}, contents...)
request.Messages[i].Content = contents
@@ -228,10 +112,10 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
- if len(relayInfo.ParamOverride) > 0 {
+ if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
- for key, value := range relayInfo.ParamOverride {
+ for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -240,14 +124,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
- if common.DebugEnabled {
- println("requestBody: ", string(jsonData))
- }
+ logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData)))
+
requestBody = bytes.NewBuffer(jsonData)
}
var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -256,125 +139,31 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
if resp != nil {
httpResp = resp.(*http.Response)
- relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
- newApiErr = service.RelayErrorHandler(httpResp, false)
+ newApiErr := service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
return newApiErr
}
}
- usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo)
+ usage, newApiErr := adaptor.DoResponse(c, httpResp, info)
if newApiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
return newApiErr
}
- if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
- service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
+ service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
}
return nil
}
-func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
- var promptTokens int
- var err error
- switch info.RelayMode {
- case relayconstant.RelayModeChatCompletions:
- promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
- case relayconstant.RelayModeCompletions:
- promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
- case relayconstant.RelayModeModerations:
- promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
- case relayconstant.RelayModeEmbeddings:
- promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
- default:
- err = errors.New("unknown relay mode")
- promptTokens = 0
- }
- info.PromptTokens = promptTokens
- return promptTokens, err
-}
-
-func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
- var err error
- var words []string
- switch info.RelayMode {
- case relayconstant.RelayModeChatCompletions:
- words, err = service.CheckSensitiveMessages(textRequest.Messages)
- case relayconstant.RelayModeCompletions:
- words, err = service.CheckSensitiveInput(textRequest.Prompt)
- case relayconstant.RelayModeModerations:
- words, err = service.CheckSensitiveInput(textRequest.Input)
- case relayconstant.RelayModeEmbeddings:
- words, err = service.CheckSensitiveInput(textRequest.Input)
- }
- return words, err
-}
-
-// 预扣费并返回用户剩余配额
-func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
- userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
- if err != nil {
- return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
- }
- if userQuota <= 0 {
- return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
- }
- if userQuota-preConsumedQuota < 0 {
- return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
- }
- relayInfo.UserQuota = userQuota
- if userQuota > 100*preConsumedQuota {
- // 用户额度充足,判断令牌额度是否充足
- if !relayInfo.TokenUnlimited {
- // 非无限令牌,判断令牌额度是否充足
- tokenQuota := c.GetInt("token_quota")
- if tokenQuota > 100*preConsumedQuota {
- // 令牌额度充足,信任令牌
- preConsumedQuota = 0
- common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
- }
- } else {
- // in this case, we do not pre-consume quota
- // because the user has enough quota
- preConsumedQuota = 0
- common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
- }
- }
-
- if preConsumedQuota > 0 {
- err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
- if err != nil {
- return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
- }
- err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
- if err != nil {
- return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
- }
- }
- return preConsumedQuota, userQuota, nil
-}
-
-func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
- if preConsumedQuota != 0 {
- gopool.Go(func() {
- relayInfoCopy := *relayInfo
-
- err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
- if err != nil {
- common.SysError("error return pre-consumed quota: " + err.Error())
- }
- })
- }
-}
-
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
- usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
@@ -392,12 +181,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
- completionRatio := priceData.CompletionRatio
- cacheRatio := priceData.CacheRatio
- imageRatio := priceData.ImageRatio
- modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatioInfo.GroupRatio
- modelPrice := priceData.ModelPrice
+ completionRatio := relayInfo.PriceData.CompletionRatio
+ cacheRatio := relayInfo.PriceData.CacheRatio
+ imageRatio := relayInfo.PriceData.ImageRatio
+ modelRatio := relayInfo.PriceData.ModelRatio
+ groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+ modelPrice := relayInfo.PriceData.ModelPrice
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -470,7 +259,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
var audioInputQuota decimal.Decimal
var audioInputPrice float64
- if !priceData.UsePrice {
+ if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// 减去 cached tokens
var cachedTokensWithRatio decimal.Decimal
@@ -518,7 +307,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
totalTokens := promptTokens + completionTokens
var logContent string
- if !priceData.UsePrice {
+ if !relayInfo.PriceData.UsePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
@@ -530,8 +319,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
- common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+ logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
if !ratio.IsZero() && quota == 0 {
quota = 1
@@ -540,11 +329,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
- quotaDelta := quota - preConsumedQuota
+ quotaDelta := quota - relayInfo.FinalPreConsumedQuota
if quotaDelta != 0 {
- err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+ err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
@@ -560,7 +349,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
- other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio
@@ -604,7 +393,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
- UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
diff --git a/relay/relay_task.go b/relay/relay_task.go
index 0ccc3b33..ae002d73 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
@@ -127,7 +128,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
+ logger.SysError("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go
index 1e547e2a..85e4f174 100644
--- a/relay/rerank_handler.go
+++ b/relay/rerank_handler.go
@@ -25,62 +25,33 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
return token
}
-func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) {
+func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- var rerankRequest *dto.RerankRequest
- err := common.UnmarshalBodyReusable(c, &rerankRequest)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ rerankRequest, ok := info.Request.(*dto.RerankRequest)
+ if !ok {
+ common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request))
}
- relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
-
- if rerankRequest.Query == "" {
- return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
- if len(rerankRequest.Documents) == 0 {
- return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
+ err := helper.ModelMappedHelper(c, info, rerankRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- promptToken := getRerankPromptToken(*rerankRequest)
- relayInfo.PromptTokens = promptToken
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
- if err != nil {
- return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
- }
- // pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if newAPIError != nil {
- return newAPIError
- }
- defer func() {
- if newAPIError != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
- convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
+ convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -90,10 +61,10 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
}
// apply param override
- if len(relayInfo.ParamOverride) > 0 {
+ if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
- for key, value := range relayInfo.ParamOverride {
+ for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -108,7 +79,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
requestBody = bytes.NewBuffer(jsonData)
}
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -125,12 +96,12 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
}
}
- usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}
diff --git a/relay/responses_handler.go b/relay/responses_handler.go
index 65c240b2..cd80da33 100644
--- a/relay/responses_handler.go
+++ b/relay/responses_handler.go
@@ -3,7 +3,6 @@ package relay
import (
"bytes"
"encoding/json"
- "errors"
"fmt"
"io"
"net/http"
@@ -12,7 +11,6 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
- "one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
@@ -20,82 +18,24 @@ import (
"github.com/gin-gonic/gin"
)
-func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
- request := &dto.OpenAIResponsesRequest{}
- err := common.UnmarshalBodyReusable(c, request)
- if err != nil {
- return nil, err
- }
- if request.Model == "" {
- return nil, errors.New("model is required")
- }
- if len(request.Input) == 0 {
- return nil, errors.New("input is required")
- }
- return request, nil
+func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
-}
-
-func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
- sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
- return sensitiveWords, err
-}
-
-func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
- inputTokens := service.CountTokenInput(req.Input, req.Model)
- info.PromptTokens = inputTokens
- return inputTokens
-}
-
-func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
- req, err := getAndValidateResponsesRequest(c)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ request, ok := info.Request.(*dto.OpenAIResponsesRequest)
+ if !ok {
+ common.FatalLog(fmt.Sprintf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request))
}
- relayInfo := relaycommon.GenRelayInfoResponses(c, req)
-
- if setting.ShouldCheckPromptSensitive() {
- sensitiveWords, err := checkInputSensitive(req, relayInfo)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
- return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
- }
- }
-
- err = helper.ModelMappedHelper(c, relayInfo, req)
+ err := helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- if value, exists := c.Get("prompt_tokens"); exists {
- promptTokens := value.(int)
- relayInfo.SetPromptTokens(promptTokens)
- } else {
- promptTokens := getInputTokens(req, relayInfo)
- c.Set("prompt_tokens", promptTokens)
- }
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
- if err != nil {
- return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
- }
- // pre consume quota
- preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if newAPIError != nil {
- return newAPIError
- }
- defer func() {
- if newAPIError != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
body, err := common.GetRequestBody(c)
@@ -104,7 +44,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody = bytes.NewBuffer(body)
} else {
- convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
+ convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -113,13 +53,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
- if len(relayInfo.ParamOverride) > 0 {
+ if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
err = json.Unmarshal(jsonData, &reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
- for key, value := range relayInfo.ParamOverride {
+ for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = json.Marshal(reqMap)
@@ -135,7 +75,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -153,17 +93,17 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
- usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
- if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
- service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
+ service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
}
return nil
}
diff --git a/relay/websocket.go b/relay/websocket.go
index 3715b237..22b681f1 100644
--- a/relay/websocket.go
+++ b/relay/websocket.go
@@ -15,13 +15,6 @@ import (
func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
- // get & validate textRequest 获取并验证文本请求
- //realtimeEvent, err := getAndValidateWssRequest(c, ws)
- //if err != nil {
- // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
- // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
- //}
-
err := helper.ModelMappedHelper(c, relayInfo, nil)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
diff --git a/router/main.go b/router/main.go
index 0d2bfdce..7653f3a5 100644
--- a/router/main.go
+++ b/router/main.go
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
+ "one-api/logger"
"os"
"strings"
)
@@ -18,7 +19,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if common.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
- common.SysLog("FRONTEND_BASE_URL is ignored on master node")
+ logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
}
if frontendBaseUrl == "" {
SetWebRouter(router, buildFS, indexPage)
diff --git a/router/relay-router.go b/router/relay-router.go
index cd656580..e0f05e97 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -1,11 +1,13 @@
package router
import (
- "github.com/gin-gonic/gin"
"one-api/constant"
"one-api/controller"
"one-api/middleware"
"one-api/relay"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
)
func SetRelayRouter(router *gin.Engine) {
@@ -62,28 +64,83 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.Use(middleware.TokenAuth())
relayV1Router.Use(middleware.ModelRequestRateLimit())
{
- // WebSocket 路由
+ // WebSocket 路由(统一到 Relay)
wsRouter := relayV1Router.Group("")
wsRouter.Use(middleware.Distribute())
- wsRouter.GET("/realtime", controller.WssRelay)
+ wsRouter.GET("/realtime", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIRealtime)
+ })
}
{
//http router
httpRouter := relayV1Router.Group("")
httpRouter.Use(middleware.Distribute())
- httpRouter.POST("/messages", controller.RelayClaude)
- httpRouter.POST("/completions", controller.Relay)
- httpRouter.POST("/chat/completions", controller.Relay)
- httpRouter.POST("/edits", controller.Relay)
- httpRouter.POST("/images/generations", controller.Relay)
- httpRouter.POST("/images/edits", controller.Relay)
+
+ // claude related routes
+ httpRouter.POST("/messages", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatClaude)
+ })
+
+ // chat related routes
+ httpRouter.POST("/completions", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAI)
+ })
+ httpRouter.POST("/chat/completions", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAI)
+ })
+
+ // response related routes
+ httpRouter.POST("/responses", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIResponses)
+ })
+
+ // image related routes
+ httpRouter.POST("/edits", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIImage)
+ })
+ httpRouter.POST("/images/generations", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIImage)
+ })
+ httpRouter.POST("/images/edits", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIImage)
+ })
+
+ // embedding related routes
+ httpRouter.POST("/embeddings", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatEmbedding)
+ })
+
+ // audio related routes
+ httpRouter.POST("/audio/transcriptions", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIAudio)
+ })
+ httpRouter.POST("/audio/translations", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIAudio)
+ })
+ httpRouter.POST("/audio/speech", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIAudio)
+ })
+
+ // rerank related routes
+ httpRouter.POST("/rerank", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatRerank)
+ })
+
+ // gemini relay routes
+ httpRouter.POST("/engines/:model/embeddings", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatGemini)
+ })
+ httpRouter.POST("/models/*path", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatGemini)
+ })
+
+ // other relay routes
+ httpRouter.POST("/moderations", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAI)
+ })
+
+ // not implemented
httpRouter.POST("/images/variations", controller.RelayNotImplemented)
- httpRouter.POST("/embeddings", controller.Relay)
- httpRouter.POST("/engines/:model/embeddings", controller.Relay)
- httpRouter.POST("/audio/transcriptions", controller.Relay)
- httpRouter.POST("/audio/translations", controller.Relay)
- httpRouter.POST("/audio/speech", controller.Relay)
- httpRouter.POST("/responses", controller.Relay)
httpRouter.GET("/files", controller.RelayNotImplemented)
httpRouter.POST("/files", controller.RelayNotImplemented)
httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
@@ -95,9 +152,6 @@ func SetRelayRouter(router *gin.Engine) {
httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
- httpRouter.POST("/moderations", controller.Relay)
- httpRouter.POST("/rerank", controller.Relay)
- httpRouter.POST("/models/*path", controller.Relay)
}
relayMjRouter := router.Group("/mj")
@@ -121,7 +175,9 @@ func SetRelayRouter(router *gin.Engine) {
relayGeminiRouter.Use(middleware.Distribute())
{
// Gemini API 路径格式: /v1beta/models/{model_name}:{action}
- relayGeminiRouter.POST("/models/*path", controller.Relay)
+ relayGeminiRouter.POST("/models/*path", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatGemini)
+ })
}
}
diff --git a/service/cf_worker.go b/service/cf_worker.go
index ae6e1ffe..65f7f133 100644
--- a/service/cf_worker.go
+++ b/service/cf_worker.go
@@ -5,7 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
- "one-api/common"
+ "one-api/logger"
"one-api/setting"
"strings"
)
@@ -44,14 +44,14 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
if setting.EnableWorker() {
- common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
+ logger.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
req := &WorkerRequest{
URL: originUrl,
Key: setting.WorkerValidKey,
}
return DoWorkerRequest(req)
} else {
- common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
+ logger.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
return http.Get(originUrl)
}
}
diff --git a/service/error.go b/service/error.go
index 9672402d..668731b0 100644
--- a/service/error.go
+++ b/service/error.go
@@ -7,6 +7,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/types"
"strconv"
"strings"
@@ -58,7 +59,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError
lowerText := strings.ToLower(text)
if !strings.HasPrefix(lowerText, "get file base64 from url") {
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
+ logger.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
}
@@ -85,7 +86,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
if err != nil {
return
}
- common.CloseResponseBodyGracefully(resp)
+ CloseResponseBodyGracefully(resp)
var errResponse dto.GeneralErrorResponse
err = common.Unmarshal(responseBody, &errResponse)
@@ -138,7 +139,7 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
text := err.Error()
lowerText := strings.ToLower(text)
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
+ logger.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
//避免暴露内部错误
diff --git a/common/http.go b/service/http.go
similarity index 86%
rename from common/http.go
rename to service/http.go
index d2e824ef..357a2e78 100644
--- a/common/http.go
+++ b/service/http.go
@@ -1,10 +1,12 @@
-package common
+package service
import (
"bytes"
"fmt"
"io"
"net/http"
+ "one-api/common"
+ "one-api/logger"
"github.com/gin-gonic/gin"
)
@@ -15,7 +17,7 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) {
}
err := httpResponse.Body.Close()
if err != nil {
- SysError("failed to close response body: " + err.Error())
+ common.SysError("failed to close response body: " + err.Error())
}
}
@@ -52,6 +54,6 @@ func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
_, err := io.Copy(c.Writer, body)
if err != nil {
- LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
+ logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
}
}
diff --git a/service/image.go b/service/image.go
index 252093f1..957ca041 100644
--- a/service/image.go
+++ b/service/image.go
@@ -8,8 +8,8 @@ import (
"image"
"io"
"net/http"
- "one-api/common"
"one-api/constant"
+ "one-api/logger"
"strings"
"golang.org/x/image/webp"
@@ -113,7 +113,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := DoDownloadRequest(imageUrl)
if err != nil {
- common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
+ logger.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err
}
defer response.Body.Close()
@@ -131,7 +131,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
- common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
+ logger.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
// 从response.Body读取更多的数据直到达到当前的限制
additionalData := make([]byte, limit-int64(len(readData)))
@@ -157,11 +157,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
config, format, err := image.DecodeConfig(reader)
if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
config, err = webp.DecodeConfig(reader)
if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
}
format = "webp"
}
diff --git a/service/midjourney.go b/service/midjourney.go
index 1fc19682..1d232739 100644
--- a/service/midjourney.go
+++ b/service/midjourney.go
@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
relayconstant "one-api/relay/constant"
"one-api/setting"
"strconv"
@@ -212,7 +213,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
defer cancel()
resp, err := GetHttpClient().Do(req)
if err != nil {
- common.SysError("do request failed: " + err.Error())
+ logger.SysError("do request failed: " + err.Error())
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
}
statusCode := resp.StatusCode
@@ -233,7 +234,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
}
- common.CloseResponseBodyGracefully(resp)
+ CloseResponseBodyGracefully(resp)
respStr := string(responseBody)
log.Printf("respStr: %s", respStr)
if respStr == "" {
diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go
new file mode 100644
index 00000000..3c4d0e7e
--- /dev/null
+++ b/service/pre_consume_quota.go
@@ -0,0 +1,72 @@
+package service
+
+import (
+ "errors"
+ "fmt"
+ "github.com/bytedance/gopkg/util/gopool"
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/logger"
+ "one-api/model"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+)
+
+func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
+ if preConsumedQuota != 0 {
+ gopool.Go(func() {
+ relayInfoCopy := *relayInfo
+
+ err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
+ if err != nil {
+ logger.SysError("error return pre-consumed quota: " + err.Error())
+ }
+ })
+ }
+}
+
+// PreConsumeQuota checks if the user has enough quota to pre-consume.
+// It returns the pre-consumed quota if successful, or an error if not.
+func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
+ userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
+ if err != nil {
+ return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
+ }
+ if userQuota <= 0 {
+ return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+ }
+ if userQuota-preConsumedQuota < 0 {
+ return 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+ }
+ relayInfo.UserQuota = userQuota
+ if userQuota > 100*preConsumedQuota {
+ // 用户额度充足,判断令牌额度是否充足
+ if !relayInfo.TokenUnlimited {
+ // 非无限令牌,判断令牌额度是否充足
+ tokenQuota := c.GetInt("token_quota")
+ if tokenQuota > 100*preConsumedQuota {
+ // 令牌额度充足,信任令牌
+ preConsumedQuota = 0
+ logger.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
+ }
+ } else {
+ // in this case, we do not pre-consume quota
+ // because the user has enough quota
+ preConsumedQuota = 0
+ logger.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota)))
+ }
+ }
+
+ if preConsumedQuota > 0 {
+ err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
+ if err != nil {
+ return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+ }
+ err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
+ if err != nil {
+ return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
+ }
+ }
+ relayInfo.FinalPreConsumedQuota = preConsumedQuota
+ return preConsumedQuota, nil
+}
diff --git a/service/quota.go b/service/quota.go
index 0f618402..d6f49d64 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -8,11 +8,12 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
- "one-api/relay/helper"
"one-api/setting"
"one-api/setting/ratio_setting"
+ "one-api/types"
"strings"
"time"
@@ -129,23 +130,23 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
quota := calculateAudioQuota(quotaInfo)
if userQuota < quota {
- return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
+ return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota))
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
- return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
+ return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
}
err = PostConsumeQuota(relayInfo, quota, 0, false)
if err != nil {
return err
}
- common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
+ logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
return nil
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
- usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+ usage *dto.RealtimeUsage, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
@@ -159,10 +160,10 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
- modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatioInfo.GroupRatio
- modelPrice := priceData.ModelPrice
- usePrice := priceData.UsePrice
+ modelRatio := relayInfo.PriceData.ModelRatio
+ groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+ modelPrice := relayInfo.PriceData.ModelPrice
+ usePrice := relayInfo.PriceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -196,8 +197,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
- common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+ logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
@@ -208,7 +209,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: usage.InputTokens,
@@ -218,7 +219,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
- UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
@@ -226,8 +226,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
})
}
-func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
- usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
@@ -235,20 +234,20 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
- completionRatio := priceData.CompletionRatio
- modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatioInfo.GroupRatio
- modelPrice := priceData.ModelPrice
- cacheRatio := priceData.CacheRatio
+ completionRatio := relayInfo.PriceData.CompletionRatio
+ modelRatio := relayInfo.PriceData.ModelRatio
+ groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+ modelPrice := relayInfo.PriceData.ModelPrice
+ cacheRatio := relayInfo.PriceData.CacheRatio
cacheTokens := usage.PromptTokensDetails.CachedTokens
- cacheCreationRatio := priceData.CacheCreationRatio
+ cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
promptTokens -= cacheTokens
- if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
- maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
+ if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 {
+ maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
if promptTokens >= maybeCacheCreationTokens {
cacheCreationTokens = maybeCacheCreationTokens
}
@@ -257,7 +256,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
calculateQuota := 0.0
- if !priceData.UsePrice {
+ if !relayInfo.PriceData.UsePrice {
calculateQuota = float64(promptTokens)
calculateQuota += float64(cacheTokens) * cacheRatio
calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
@@ -282,23 +281,23 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游出错)")
- common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+ logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
- quotaDelta := quota - preConsumedQuota
+ quotaDelta := quota - relayInfo.FinalPreConsumedQuota
if quotaDelta != 0 {
- err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+ err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
- cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,
@@ -308,7 +307,6 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
- UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
@@ -317,7 +315,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
-func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
+func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
if priceData.CacheCreationRatio == 1 {
return 0
}
@@ -338,8 +336,7 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData
(promptCacheCreatePrice - quotaPrice)))
}
-func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
- usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -353,10 +350,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
- modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatioInfo.GroupRatio
- modelPrice := priceData.ModelPrice
- usePrice := priceData.UsePrice
+ modelRatio := relayInfo.PriceData.ModelRatio
+ groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+ modelPrice := relayInfo.PriceData.ModelPrice
+ usePrice := relayInfo.PriceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -390,18 +387,18 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
- common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
+ logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
- quotaDelta := quota - preConsumedQuota
+ quotaDelta := quota - relayInfo.FinalPreConsumedQuota
if quotaDelta != 0 {
- err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+ err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
@@ -410,7 +407,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: usage.PromptTokens,
@@ -420,7 +417,6 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
- UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
@@ -443,7 +439,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
- return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
+ return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
}
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
@@ -501,7 +497,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}"
- err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
+ err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
}
diff --git a/service/token_counter.go b/service/token_counter.go
index eed5b5ca..ec817182 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -4,18 +4,22 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/tiktoken-go/tokenizer"
- "github.com/tiktoken-go/tokenizer/codec"
"image"
"log"
"math"
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
+ "one-api/types"
"strings"
"sync"
"unicode/utf8"
+
+ "github.com/gin-gonic/gin"
+ "github.com/tiktoken-go/tokenizer"
+ "github.com/tiktoken-go/tokenizer/codec"
)
// tokenEncoderMap won't grow after initialization
@@ -28,9 +32,9 @@ var tokenEncoderMap = make(map[string]tokenizer.Codec)
var tokenEncoderMutex sync.RWMutex
func InitTokenEncoders() {
- common.SysLog("initializing token encoders")
+ logger.SysLog("initializing token encoders")
defaultTokenEncoder = codec.NewCl100kBase()
- common.SysLog("token encoders initialized")
+ logger.SysLog("token encoders initialized")
}
func getTokenEncoder(model string) tokenizer.Codec {
@@ -72,52 +76,95 @@ func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
return tkm
}
-func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
- if imageUrl == nil {
+func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
+ if fileMeta == nil {
return 0, fmt.Errorf("image_url_is_nil")
}
+
+ // Defaults for 4o/4.1/4.5 family unless overridden below
baseTokens := 85
- if model == "glm-4v" {
+ tileTokens := 170
+
+ // Model classification
+ lowerModel := strings.ToLower(model)
+
+ // Special cases from existing behavior
+ if strings.HasPrefix(lowerModel, "glm-4") {
return 1047, nil
}
- if imageUrl.Detail == "low" {
+
+ // Patch-based models (32x32 patches, capped at 1536, with multiplier)
+ isPatchBased := false
+ multiplier := 1.0
+ switch {
+ case strings.Contains(lowerModel, "gpt-4.1-mini"):
+ isPatchBased = true
+ multiplier = 1.62
+ case strings.Contains(lowerModel, "gpt-4.1-nano"):
+ isPatchBased = true
+ multiplier = 2.46
+ case strings.HasPrefix(lowerModel, "o4-mini"):
+ isPatchBased = true
+ multiplier = 1.72
+ case strings.HasPrefix(lowerModel, "gpt-5-mini"):
+ isPatchBased = true
+ multiplier = 1.62
+ case strings.HasPrefix(lowerModel, "gpt-5-nano"):
+ isPatchBased = true
+ multiplier = 2.46
+ }
+
+ // Tile-based model tokens and bases per doc
+ if !isPatchBased {
+ if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
+ baseTokens = 2833
+ tileTokens = 5667
+ } else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
+ baseTokens = 70
+ tileTokens = 140
+ } else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
+ baseTokens = 75
+ tileTokens = 150
+ } else if strings.Contains(lowerModel, "computer-use-preview") {
+ baseTokens = 65
+ tileTokens = 129
+ } else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
+ baseTokens = 85
+ tileTokens = 170
+ }
+ }
+
+ // Respect existing feature flags/short-circuits
+ if fileMeta.Detail == "low" && !isPatchBased {
return baseTokens, nil
}
if !constant.GetMediaTokenNotStream && !stream {
return 3 * baseTokens, nil
}
-
- // 同步One API的图片计费逻辑
- if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
- imageUrl.Detail = "high"
+ // Normalize detail
+ if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
+ fileMeta.Detail = "high"
}
-
- tileTokens := 170
- if strings.HasPrefix(model, "gpt-4o-mini") {
- tileTokens = 5667
- baseTokens = 2833
- }
- // 是否统计图片token
+ // Whether to count image tokens at all
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
- if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
- return 3 * baseTokens, nil
- }
+
+ // Decode image to get dimensions
var config image.Config
var err error
var format string
var b64str string
- if strings.HasPrefix(imageUrl.Url, "http") {
- config, format, err = DecodeUrlImageData(imageUrl.Url)
+ if strings.HasPrefix(fileMeta.Data, "http") {
+ config, format, err = DecodeUrlImageData(fileMeta.Data)
} else {
- common.SysLog(fmt.Sprintf("decoding image"))
- config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url)
+ logger.SysLog(fmt.Sprintf("decoding image"))
+ config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
}
if err != nil {
return 0, err
}
- imageUrl.MimeType = format
+ fileMeta.MimeType = format
if config.Width == 0 || config.Height == 0 {
// not an image
@@ -125,60 +172,144 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
// file type
return 3 * baseTokens, nil
}
- return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url))
+ return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data))
}
- shortSide := config.Width
- otherSide := config.Height
- log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
- // 缩放倍数
- scale := 1.0
- if config.Height < shortSide {
- shortSide = config.Height
- otherSide = config.Width
+ width := config.Width
+ height := config.Height
+ log.Printf("format: %s, width: %d, height: %d", format, width, height)
+
+ if isPatchBased {
+ // 32x32 patch-based calculation with 1536 cap and model multiplier
+ ceilDiv := func(a, b int) int { return (a + b - 1) / b }
+ rawPatchesW := ceilDiv(width, 32)
+ rawPatchesH := ceilDiv(height, 32)
+ rawPatches := rawPatchesW * rawPatchesH
+ if rawPatches > 1536 {
+ // scale down
+ area := float64(width * height)
+ r := math.Sqrt(float64(32*32*1536) / area)
+ wScaled := float64(width) * r
+ hScaled := float64(height) * r
+ // adjust to fit whole number of patches after scaling
+ adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
+ adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
+ adj := math.Min(adjW, adjH)
+ if !math.IsNaN(adj) && adj > 0 {
+ r = r * adj
+ }
+ wScaled = float64(width) * r
+ hScaled = float64(height) * r
+ patchesW := math.Ceil(wScaled / 32.0)
+ patchesH := math.Ceil(hScaled / 32.0)
+ imageTokens := int(patchesW * patchesH)
+ if imageTokens > 1536 {
+ imageTokens = 1536
+ }
+ return int(math.Round(float64(imageTokens) * multiplier)), nil
+ }
+ // below cap
+ imageTokens := rawPatches
+ return int(math.Round(float64(imageTokens) * multiplier)), nil
}
- // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768
- if shortSide > 768 {
- scale = float64(shortSide) / 768
- shortSide = 768
+ // Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
+ // Step 1: fit within 2048x2048 square
+ maxSide := math.Max(float64(width), float64(height))
+ fitScale := 1.0
+ if maxSide > 2048 {
+ fitScale = maxSide / 2048.0
}
- // 将另一边按照相同的比例缩小,向上取整
- otherSide = int(math.Ceil(float64(otherSide) / scale))
- log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
- // 计算图片的token数量(边的长度除以512,向上取整)
- tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
- log.Printf("tiles: %d", tiles)
+ fitW := int(math.Round(float64(width) / fitScale))
+ fitH := int(math.Round(float64(height) / fitScale))
+
+ // Step 2: scale so that shortest side is exactly 768
+ minSide := math.Min(float64(fitW), float64(fitH))
+ if minSide == 0 {
+ return baseTokens, nil
+ }
+ shortScale := 768.0 / minSide
+ finalW := int(math.Round(float64(fitW) * shortScale))
+ finalH := int(math.Round(float64(fitH) * shortScale))
+
+ // Count 512px tiles
+ tilesW := (finalW + 512 - 1) / 512
+ tilesH := (finalH + 512 - 1) / 512
+ tiles := tilesW * tilesH
+
+ if common.DebugEnabled {
+ log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
+ }
+
return tiles*tileTokens + baseTokens, nil
}
-func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
- tkm := 0
- msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
- if err != nil {
- return 0, err
+func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
+ if meta == nil {
+ return 0, errors.New("token count meta is nil")
}
- tkm += msgTokens
- if request.Tools != nil {
- openaiTools := request.Tools
- countStr := ""
- for _, tool := range openaiTools {
- countStr = tool.Function.Name
- if tool.Function.Description != "" {
- countStr += tool.Function.Description
- }
- if tool.Function.Parameters != nil {
- countStr += fmt.Sprintf("%v", tool.Function.Parameters)
- }
- }
- toolTokens := CountTokenInput(countStr, request.Model)
- tkm += 8
- tkm += toolTokens
+ model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
+ tkm := CountTextToken(meta.CombineText, model)
+
+ if info.RelayFormat == types.RelayFormatOpenAI {
+ tkm += meta.ToolsCount * 8
+ tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
+ tkm += meta.NameCount * 3
+ tkm += 3
}
+ for _, file := range meta.Files {
+ switch file.FileType {
+ case types.FileTypeImage:
+ if info.RelayFormat == types.RelayFormatGemini {
+ tkm += 240
+ } else {
+ token, err := getImageToken(file, model, info.IsStream)
+ if err != nil {
+ return 0, fmt.Errorf("error counting image token: %v", err)
+ }
+ tkm += token
+ }
+ case types.FileTypeAudio:
+ tkm += 100
+ case types.FileTypeVideo:
+ tkm += 5000
+ case types.FileTypeFile:
+ tkm += 5000
+ }
+ }
+
+ common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
return tkm, nil
}
+//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
+// tkm := 0
+// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
+// if err != nil {
+// return 0, err
+// }
+// tkm += msgTokens
+// if request.Tools != nil {
+// openaiTools := request.Tools
+// countStr := ""
+// for _, tool := range openaiTools {
+// countStr = tool.Function.Name
+// if tool.Function.Description != "" {
+// countStr += tool.Function.Description
+// }
+// if tool.Function.Parameters != nil {
+// countStr += fmt.Sprintf("%v", tool.Function.Parameters)
+// }
+// }
+// toolTokens := CountTokenInput(countStr, request.Model)
+// tkm += 8
+// tkm += toolTokens
+// }
+//
+// return tkm, nil
+//}
+
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
tkm := 0
@@ -338,58 +469,55 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
return textToken, audioToken, nil
}
-func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
- //recover when panic
- tokenEncoder := getTokenEncoder(model)
- // Reference:
- // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- // https://github.com/pkoukk/tiktoken-go/issues/6
- //
- // Every message follows <|start|>{role/name}\n{content}<|end|>\n
- var tokensPerMessage int
- var tokensPerName int
- if model == "gpt-3.5-turbo-0301" {
- tokensPerMessage = 4
- tokensPerName = -1 // If there's a name, the role is omitted
- } else {
- tokensPerMessage = 3
- tokensPerName = 1
- }
- tokenNum := 0
- for _, message := range messages {
- tokenNum += tokensPerMessage
- tokenNum += getTokenNum(tokenEncoder, message.Role)
- if message.Content != nil {
- if message.Name != nil {
- tokenNum += tokensPerName
- tokenNum += getTokenNum(tokenEncoder, *message.Name)
- }
- arrayContent := message.ParseContent()
- for _, m := range arrayContent {
- if m.Type == dto.ContentTypeImageURL {
- imageUrl := m.GetImageMedia()
- imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
- if err != nil {
- return 0, err
- }
- tokenNum += imageTokenNum
- log.Printf("image token num: %d", imageTokenNum)
- } else if m.Type == dto.ContentTypeInputAudio {
- // TODO: 音频token数量计算
- tokenNum += 100
- } else if m.Type == dto.ContentTypeFile {
- tokenNum += 5000
- } else if m.Type == dto.ContentTypeVideoUrl {
- tokenNum += 5000
- } else {
- tokenNum += getTokenNum(tokenEncoder, m.Text)
- }
- }
- }
- }
- tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
- return tokenNum, nil
-}
+//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
+// //recover when panic
+// tokenEncoder := getTokenEncoder(model)
+// // Reference:
+// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+// // https://github.com/pkoukk/tiktoken-go/issues/6
+// //
+// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
+// var tokensPerMessage int
+// var tokensPerName int
+//
+// tokensPerMessage = 3
+// tokensPerName = 1
+//
+// tokenNum := 0
+// for _, message := range messages {
+// tokenNum += tokensPerMessage
+// tokenNum += getTokenNum(tokenEncoder, message.Role)
+// if message.Content != nil {
+// if message.Name != nil {
+// tokenNum += tokensPerName
+// tokenNum += getTokenNum(tokenEncoder, *message.Name)
+// }
+// arrayContent := message.ParseContent()
+// for _, m := range arrayContent {
+// if m.Type == dto.ContentTypeImageURL {
+// imageUrl := m.GetImageMedia()
+// imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
+// if err != nil {
+// return 0, err
+// }
+// tokenNum += imageTokenNum
+// log.Printf("image token num: %d", imageTokenNum)
+// } else if m.Type == dto.ContentTypeInputAudio {
+// // TODO: 音频token数量计算
+// tokenNum += 100
+// } else if m.Type == dto.ContentTypeFile {
+// tokenNum += 5000
+// } else if m.Type == dto.ContentTypeVideoUrl {
+// tokenNum += 5000
+// } else {
+// tokenNum += getTokenNum(tokenEncoder, m.Text)
+// }
+// }
+// }
+// }
+// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+// return tokenNum, nil
+//}
func CountTokenInput(input any, model string) int {
switch v := input.(type) {
diff --git a/service/user_notify.go b/service/user_notify.go
index 96664007..1fcc62d3 100644
--- a/service/user_notify.go
+++ b/service/user_notify.go
@@ -4,6 +4,7 @@ import (
"fmt"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
"strings"
)
@@ -12,7 +13,7 @@ func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
if err != nil {
- common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
}
}
@@ -25,7 +26,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
// Check notification limit
canSend, err := CheckNotificationLimit(userId, data.Type)
if err != nil {
- common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
@@ -37,14 +38,14 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
// check setting email
userEmail = userSetting.NotificationEmail
if userEmail == "" {
- common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
+ logger.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
return sendEmailNotify(userEmail, data)
case dto.NotifyTypeWebhook:
webhookURLStr := userSetting.WebhookUrl
if webhookURLStr == "" {
- common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
+ logger.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}
diff --git a/setting/chat.go b/setting/chat.go
index b97d65ce..b417af28 100644
--- a/setting/chat.go
+++ b/setting/chat.go
@@ -2,7 +2,7 @@ package setting
import (
"encoding/json"
- "one-api/common"
+ "one-api/logger"
)
var Chats = []map[string]string{
@@ -37,7 +37,7 @@ func UpdateChatsByJsonString(jsonString string) error {
func Chats2JsonString() string {
jsonBytes, err := json.Marshal(Chats)
if err != nil {
- common.SysError("error marshalling chats: " + err.Error())
+ logger.SysError("error marshalling chats: " + err.Error())
return "[]"
}
return string(jsonBytes)
diff --git a/setting/config/config.go b/setting/config/config.go
index 3af51b14..2e43e0a7 100644
--- a/setting/config/config.go
+++ b/setting/config/config.go
@@ -2,7 +2,7 @@ package config
import (
"encoding/json"
- "one-api/common"
+ "one-api/logger"
"reflect"
"strconv"
"strings"
@@ -57,7 +57,7 @@ func (cm *ConfigManager) LoadFromDB(options map[string]string) error {
// 如果找到配置项,则更新配置
if len(configMap) > 0 {
if err := updateConfigFromMap(config, configMap); err != nil {
- common.SysError("failed to update config " + name + ": " + err.Error())
+ logger.SysError("failed to update config " + name + ": " + err.Error())
continue
}
}
diff --git a/setting/rate_limit.go b/setting/rate_limit.go
index d550b2c3..dcb9fae5 100644
--- a/setting/rate_limit.go
+++ b/setting/rate_limit.go
@@ -4,7 +4,7 @@ import (
"encoding/json"
"fmt"
"math"
- "one-api/common"
+ "one-api/logger"
"sync"
)
@@ -21,7 +21,7 @@ func ModelRequestRateLimitGroup2JSONString() string {
jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
if err != nil {
- common.SysError("error marshalling model ratio: " + err.Error())
+ logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go
index 3f223bc3..47079850 100644
--- a/setting/ratio_setting/cache_ratio.go
+++ b/setting/ratio_setting/cache_ratio.go
@@ -2,7 +2,7 @@ package ratio_setting
import (
"encoding/json"
- "one-api/common"
+ "one-api/logger"
"sync"
)
@@ -89,7 +89,7 @@ func CacheRatio2JSONString() string {
defer cacheRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(cacheRatioMap)
if err != nil {
- common.SysError("error marshalling cache ratio: " + err.Error())
+ logger.SysError("error marshalling cache ratio: " + err.Error())
}
return string(jsonBytes)
}
diff --git a/setting/ratio_setting/group_ratio.go b/setting/ratio_setting/group_ratio.go
index 86f4a8d1..c1a666e9 100644
--- a/setting/ratio_setting/group_ratio.go
+++ b/setting/ratio_setting/group_ratio.go
@@ -3,7 +3,7 @@ package ratio_setting
import (
"encoding/json"
"errors"
- "one-api/common"
+ "one-api/logger"
"sync"
)
@@ -48,7 +48,7 @@ func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(groupRatio)
if err != nil {
- common.SysError("error marshalling model ratio: " + err.Error())
+ logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -67,7 +67,7 @@ func GetGroupRatio(name string) float64 {
ratio, ok := groupRatio[name]
if !ok {
- common.SysError("group ratio not found: " + name)
+ logger.SysError("group ratio not found: " + name)
return 1
}
return ratio
@@ -94,7 +94,7 @@ func GroupGroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupGroupRatio)
if err != nil {
- common.SysError("error marshalling group-group ratio: " + err.Error())
+ logger.SysError("error marshalling group-group ratio: " + err.Error())
}
return string(jsonBytes)
}
diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go
index 4a19895e..ce822800 100644
--- a/setting/ratio_setting/model_ratio.go
+++ b/setting/ratio_setting/model_ratio.go
@@ -320,7 +320,7 @@ func ModelPrice2JSONString() string {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
- jsonBytes, err := json.Marshal(modelPriceMap)
+ jsonBytes, err := common.Marshal(modelPriceMap)
if err != nil {
common.SysError("error marshalling model price: " + err.Error())
}
@@ -359,7 +359,7 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
- err := json.Unmarshal([]byte(jsonStr), &modelRatioMap)
+ err := common.Unmarshal([]byte(jsonStr), &modelRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
@@ -388,7 +388,7 @@ func GetModelRatio(name string) (float64, bool, string) {
}
func DefaultModelRatio2JSONString() string {
- jsonBytes, err := json.Marshal(defaultModelRatio)
+ jsonBytes, err := common.Marshal(defaultModelRatio)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
@@ -420,7 +420,7 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
- err := json.Unmarshal([]byte(jsonStr), &CompletionRatio)
+ err := common.Unmarshal([]byte(jsonStr), &CompletionRatio)
if err == nil {
InvalidateExposedDataCache()
}
@@ -594,7 +594,7 @@ func ModelRatio2JSONString() string {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
- jsonBytes, err := json.Marshal(modelRatioMap)
+ jsonBytes, err := common.Marshal(modelRatioMap)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
@@ -610,7 +610,7 @@ var imageRatioMapMutex sync.RWMutex
func ImageRatio2JSONString() string {
imageRatioMapMutex.RLock()
defer imageRatioMapMutex.RUnlock()
- jsonBytes, err := json.Marshal(imageRatioMap)
+ jsonBytes, err := common.Marshal(imageRatioMap)
if err != nil {
common.SysError("error marshalling cache ratio: " + err.Error())
}
@@ -621,7 +621,7 @@ func UpdateImageRatioByJSONString(jsonStr string) error {
imageRatioMapMutex.Lock()
defer imageRatioMapMutex.Unlock()
imageRatioMap = make(map[string]float64)
- return json.Unmarshal([]byte(jsonStr), &imageRatioMap)
+ return common.Unmarshal([]byte(jsonStr), &imageRatioMap)
}
func GetImageRatio(name string) (float64, bool) {
diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go
index 0ae132d0..bcbe712c 100644
--- a/setting/user_usable_group.go
+++ b/setting/user_usable_group.go
@@ -2,7 +2,7 @@ package setting
import (
"encoding/json"
- "one-api/common"
+ "one-api/logger"
"sync"
)
@@ -29,7 +29,7 @@ func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(userUsableGroups)
if err != nil {
- common.SysError("error marshalling user groups: " + err.Error())
+ logger.SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
diff --git a/types/error.go b/types/error.go
index 5a143612..2cfeb541 100644
--- a/types/error.go
+++ b/types/error.go
@@ -39,12 +39,13 @@ const (
ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected"
// new api error
- ErrorCodeCountTokenFailed ErrorCode = "count_token_failed"
- ErrorCodeModelPriceError ErrorCode = "model_price_error"
- ErrorCodeInvalidApiType ErrorCode = "invalid_api_type"
- ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed"
- ErrorCodeDoRequestFailed ErrorCode = "do_request_failed"
- ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed"
+ ErrorCodeCountTokenFailed ErrorCode = "count_token_failed"
+ ErrorCodeModelPriceError ErrorCode = "model_price_error"
+ ErrorCodeInvalidApiType ErrorCode = "invalid_api_type"
+ ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed"
+ ErrorCodeDoRequestFailed ErrorCode = "do_request_failed"
+ ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed"
+ ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed"
// channel error
ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key"
diff --git a/types/price_data.go b/types/price_data.go
new file mode 100644
index 00000000..f6a92d7e
--- /dev/null
+++ b/types/price_data.go
@@ -0,0 +1,31 @@
+package types
+
+import "fmt"
+
+type GroupRatioInfo struct {
+ GroupRatio float64
+ GroupSpecialRatio float64
+ HasSpecialRatio bool
+}
+
+type PriceData struct {
+ ModelPrice float64
+ ModelRatio float64
+ CompletionRatio float64
+ CacheRatio float64
+ CacheCreationRatio float64
+ ImageRatio float64
+ UsePrice bool
+ ShouldPreConsumedQuota int
+ GroupRatioInfo GroupRatioInfo
+}
+
+type PerCallPriceData struct {
+ ModelPrice float64
+ Quota int
+ GroupRatioInfo GroupRatioInfo
+}
+
+func (p PriceData) ToSetting() string {
+ return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+}
diff --git a/types/relay_format.go b/types/relay_format.go
new file mode 100644
index 00000000..4c29d649
--- /dev/null
+++ b/types/relay_format.go
@@ -0,0 +1,15 @@
+package types
+
+type RelayFormat string
+
+const (
+ RelayFormatOpenAI RelayFormat = "openai"
+ RelayFormatClaude = "claude"
+ RelayFormatGemini = "gemini"
+ RelayFormatOpenAIResponses = "openai_responses"
+ RelayFormatOpenAIAudio = "openai_audio"
+ RelayFormatOpenAIImage = "openai_image"
+ RelayFormatOpenAIRealtime = "openai_realtime"
+ RelayFormatRerank = "rerank"
+ RelayFormatEmbedding = "embedding"
+)
diff --git a/types/relay_request.go b/types/relay_request.go
new file mode 100644
index 00000000..b9d092f0
--- /dev/null
+++ b/types/relay_request.go
@@ -0,0 +1,27 @@
+package types
+
+type RelayRequest struct {
+ OriginRequest any
+ Format RelayFormat
+ PromptTokenCount int
+}
+
+func (r *RelayRequest) CopyOriginRequest() any {
+ if r.OriginRequest == nil {
+ return nil
+ }
+ switch v := r.OriginRequest.(type) {
+ case *GeneralOpenAIRequest:
+ return v.Copy()
+ case *GeneralClaudeRequest:
+ return v.Copy()
+ case *GeneralGeminiRequest:
+ return v.Copy()
+ case *GeneralRerankRequest:
+ return v.Copy()
+ case *GeneralEmbeddingRequest:
+ return v.Copy()
+ default:
+ return nil
+ }
+}
diff --git a/types/request_meta.go b/types/request_meta.go
new file mode 100644
index 00000000..427bacb9
--- /dev/null
+++ b/types/request_meta.go
@@ -0,0 +1,45 @@
+package types
+
+type FileType string
+
+const (
+ FileTypeImage FileType = "image" // Image file type
+ FileTypeAudio FileType = "audio" // Audio file type
+ FileTypeVideo FileType = "video" // Video file type
+ FileTypeFile FileType = "file" // Generic file type
+)
+
+type TokenType string
+
+const (
+ TokenTypeTextNumber TokenType = "text_number" // Text or number tokens
+ TokenTypeTokenizer TokenType = "tokenizer" // Tokenizer tokens
+ TokenTypeImage TokenType = "image" // Image tokens
+)
+
+type TokenCountMeta struct {
+ TokenType TokenType `json:"token_type,omitempty"` // Type of tokens used in the request
+ CombineText string `json:"combine_text,omitempty"` // Combined text from all messages
+ ToolsCount int `json:"tools_count,omitempty"` // Number of tools used
+ NameCount int `json:"name_count,omitempty"` // Number of names in the request
+ MessagesCount int `json:"messages_count,omitempty"` // Number of messages in the request
+ Files []*FileMeta `json:"files,omitempty"` // List of files, each with type and content
+ MaxTokens int `json:"max_tokens,omitempty"` // Maximum tokens allowed in the request
+
+ ImagePriceRatio float64 `json:"image_ratio,omitempty"` // Ratio for image size, if applicable
+ //IsStreaming bool `json:"is_streaming,omitempty"` // Indicates if the request is streaming
+}
+
+type FileMeta struct {
+ FileType
+ MimeType string
+ Data string
+ Detail string
+}
+
+type RequestMeta struct {
+ OriginalModelName string `json:"original_model_name"`
+ UserUsingGroup string `json:"user_using_group"`
+ PromptTokens int `json:"prompt_tokens"`
+ PreConsumedQuota int `json:"pre_consumed_quota"`
+}