From e2037ad756de4219a4c17b4b31e938e78cfe6df0 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 14 Aug 2025 20:05:06 +0800 Subject: [PATCH 01/12] refactor: Introduce pre-consume quota and unify relay handlers This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic. Key changes: - **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests. - **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels. - **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package. - **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure. --- common/limiter/limiter.go | 4 +- common/logger.go | 99 --- constant/context_key.go | 2 + controller/channel-billing.go | 5 +- controller/channel-test.go | 11 +- controller/console_migrate.go | 182 ++-- controller/github.go | 5 +- controller/midjourney.go | 37 +- controller/oidc.go | 11 +- controller/ratio_sync.go | 812 +++++++++--------- controller/relay.go | 358 ++++---- controller/task.go | 39 +- controller/task_video.go | 20 +- controller/token.go | 3 +- controller/topup.go | 3 +- controller/twofa.go | 13 +- controller/user.go | 9 +- dto/audio.go | 18 + dto/claude.go | 128 ++- dto/embedding.go | 28 +- dto/gemini.go | 65 +- dto/{dalle.go => openai_image.go} | 46 +- dto/openai_request.go | 295 ++++++- dto/request_common.go | 11 + dto/rerank.go | 27 + logger/logger.go | 115 +++ main.go | 37 +- middleware/recover.go | 6 +- middleware/turnstile-check.go | 5 +- middleware/utils.go | 5 +- model/ability.go | 9 +- model/channel.go | 27 +- model/channel_cache.go | 5 +- model/log.go | 12 +- model/main.go | 31 +- model/option.go | 5 +- model/pricing.go | 3 +- model/redemption.go | 3 +- model/token.go | 19 +- model/topup.go | 3 +- model/twofa.go | 9 +- model/usedata.go | 9 +- model/user.go | 33 +- model/user_cache.go | 5 +- model/utils.go | 9 +- relay/audio_handler.go | 93 +- relay/channel/ali/image.go | 12 +- relay/channel/ali/rerank.go | 4 +- relay/channel/ali/text.go | 12 +- relay/channel/api_request.go | 3 +- relay/channel/baidu/relay-baidu.go | 11 +- relay/channel/claude/relay-claude.go | 15 +- relay/channel/cloudflare/relay_cloudflare.go | 16 +- relay/channel/cohere/relay-cohere.go | 9 +- relay/channel/coze/relay-coze.go | 13 +- relay/channel/dify/relay-dify.go | 25 +- relay/channel/gemini/adaptor.go | 2 +- relay/channel/gemini/relay-gemini-native.go | 13 +- relay/channel/gemini/relay-gemini.go | 17 +- relay/channel/jimeng/image.go | 4 +- relay/channel/jimeng/sign.go | 4 +- relay/channel/mokaai/relay-mokaai.go | 5 +- relay/channel/ollama/relay-ollama.go | 4 +- relay/channel/openai/helper.go | 17 +- relay/channel/openai/relay-openai.go | 41 +- relay/channel/openai/relay_responses.go | 7 +- relay/channel/palm/relay-palm.go | 15 +- .../channel/siliconflow/relay-siliconflow.go | 6 +- relay/channel/task/suno/adaptor.go | 3 +- relay/channel/tencent/relay-tencent.go | 13 +- relay/channel/xai/text.go | 11 +- relay/channel/xunfei/relay-xunfei.go | 11 +- relay/channel/zhipu/relay-zhipu.go | 14 +- relay/{relay-mj.go => chat_handler.go} | 9 +- relay/claude_handler.go | 91 +- relay/common/relay_info.go | 321 +++---- relay/common_handler/rerank.go | 3 +- relay/embedding_handler.go | 69 +- relay/gemini_handler.go | 203 +---- relay/helper/common.go | 5 +- relay/helper/model_mapped.go | 17 +- relay/helper/price.go | 90 +- relay/helper/stream_scanner.go | 23 +- relay/helper/valid_request.go | 301 +++++++ relay/image_handler.go | 194 +---- relay/relay-text.go | 300 +------ relay/relay_task.go | 3 +- relay/rerank_handler.go | 59 +- relay/responses_handler.go | 94 +- relay/websocket.go | 7 - router/main.go | 3 +- router/relay-router.go | 94 +- service/cf_worker.go | 6 +- service/error.go | 7 +- {common => service}/http.go | 8 +- service/image.go | 10 +- service/midjourney.go | 5 +- service/pre_consume_quota.go | 72 ++ service/quota.go | 90 +- service/token_counter.go | 366 +++++--- service/user_notify.go | 9 +- setting/chat.go | 4 +- setting/config/config.go | 4 +- setting/rate_limit.go | 4 +- setting/ratio_setting/cache_ratio.go | 4 +- setting/ratio_setting/group_ratio.go | 8 +- setting/ratio_setting/model_ratio.go | 14 +- setting/user_usable_group.go | 4 +- types/error.go | 13 +- types/price_data.go | 31 + types/relay_format.go | 15 + types/relay_request.go | 27 + types/request_meta.go | 45 + 113 files changed, 3095 insertions(+), 2518 deletions(-) rename dto/{dalle.go => openai_image.go} (51%) create mode 100644 dto/request_common.go create mode 100644 logger/logger.go rename relay/{relay-mj.go => chat_handler.go} (98%) create mode 100644 relay/helper/valid_request.go rename {common => service}/http.go (86%) create mode 100644 service/pre_consume_quota.go create mode 100644 types/price_data.go create mode 100644 types/relay_format.go create mode 100644 types/relay_request.go create mode 100644 types/request_meta.go diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index ef5d1935..fcfcb0c3 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -5,7 +5,7 @@ import ( _ "embed" "fmt" "github.com/go-redis/redis/v8" - "one-api/common" + "one-api/logger" "sync" ) @@ -27,7 +27,7 @@ func New(ctx context.Context, r *redis.Client) *RedisLimiter { // 预加载脚本 limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result() if err != nil { - common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) + logger.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) } instance = &RedisLimiter{ client: r, diff --git a/common/logger.go b/common/logger.go index 0f6dc3c3..478015f0 100644 --- a/common/logger.go +++ b/common/logger.go @@ -1,52 +1,12 @@ package common import ( - "context" - "encoding/json" "fmt" - "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" - "io" - "log" "os" - "path/filepath" - "sync" "time" ) -const ( - loggerINFO = "INFO" - loggerWarn = "WARN" - loggerError = "ERR" -) - -const maxLogCount = 1000000 - -var logCount int -var setupLogLock sync.Mutex -var setupLogWorking bool - -func SetupLogger() { - if *LogDir != "" { - ok := setupLogLock.TryLock() - if !ok { - log.Println("setup log is already working") - return - } - defer func() { - setupLogLock.Unlock() - setupLogWorking = false - }() - logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) - fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Fatal("failed to open log file") - } - gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) - gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) - } -} - func SysLog(s string) { t := time.Now() _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) @@ -57,67 +17,8 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) } -func LogInfo(ctx context.Context, msg string) { - logHelper(ctx, loggerINFO, msg) -} - -func LogWarn(ctx context.Context, msg string) { - logHelper(ctx, loggerWarn, msg) -} - -func LogError(ctx context.Context, msg string) { - logHelper(ctx, loggerError, msg) -} - -func logHelper(ctx context.Context, level string, msg string) { - writer := gin.DefaultErrorWriter - if level == loggerINFO { - writer = gin.DefaultWriter - } - id := ctx.Value(RequestIdKey) - if id == nil { - id = "SYSTEM" - } - now := time.Now() - _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) - logCount++ // we don't need accurate count, so no lock here - if logCount > maxLogCount && !setupLogWorking { - logCount = 0 - setupLogWorking = true - gopool.Go(func() { - SetupLogger() - }) - } -} - func FatalLog(v ...any) { t := time.Now() _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) os.Exit(1) } - -func LogQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) - } else { - return fmt.Sprintf("%d 点额度", quota) - } -} - -func FormatQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit) - } else { - return fmt.Sprintf("%d", quota) - } -} - -// LogJson 仅供测试使用 only for test -func LogJson(ctx context.Context, msg string, obj any) { - jsonStr, err := json.Marshal(obj) - if err != nil { - LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) - return - } - LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr))) -} diff --git a/constant/context_key.go b/constant/context_key.go index b82b19e7..569a0373 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -3,6 +3,8 @@ package constant type ContextKey string const ( + ContextKeyPromptTokens ContextKey = "prompt_tokens" + ContextKeyOriginalModel ContextKey = "original_model" ContextKeyRequestStartTime ContextKey = "request_start_time" diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 5152e060..bbf0f97a 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/constant" + "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -485,8 +486,8 @@ func UpdateAllChannelsBalance(c *gin.Context) { func AutomaticallyUpdateChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("updating all channels") + logger.SysLog("updating all channels") _ = updateAllChannelsBalance() - common.SysLog("channels update done") + logger.SysLog("channels update done") } } diff --git a/controller/channel-test.go b/controller/channel-test.go index 026a863b..ec2e6226 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -13,6 +13,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/relay" @@ -159,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { // 创建一个用于日志的 info 副本,移除 ApiKey logInfo := *info logInfo.ApiKey = "" - common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) + logger.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens())) if err != nil { @@ -279,7 +280,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { Group: info.UsingGroup, Other: other, }) - common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) + logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return testResult{ context: c, localErr: nil, @@ -461,13 +462,13 @@ func TestAllChannels(c *gin.Context) { func AutomaticallyTestChannels(frequency int) { if frequency <= 0 { - common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") + logger.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") return } for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("testing all channels") + logger.SysLog("testing all channels") _ = testAllChannels(false) - common.SysLog("channel test finished") + logger.SysLog("channel test finished") } } diff --git a/controller/console_migrate.go b/controller/console_migrate.go index d25f199b..d21f5e21 100644 --- a/controller/console_migrate.go +++ b/controller/console_migrate.go @@ -3,101 +3,101 @@ package controller import ( - "encoding/json" - "net/http" - "one-api/common" - "one-api/model" - "github.com/gin-gonic/gin" + "encoding/json" + "github.com/gin-gonic/gin" + "net/http" + "one-api/logger" + "one-api/model" ) // MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.* func MigrateConsoleSetting(c *gin.Context) { - // 读取全部 option - opts, err := model.AllOption() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) - return - } - // 建立 map - valMap := map[string]string{} - for _, o := range opts { - valMap[o.Key] = o.Value - } + // 读取全部 option + opts, err := model.AllOption() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) + return + } + // 建立 map + valMap := map[string]string{} + for _, o := range opts { + valMap[o.Key] = o.Value + } - // 处理 APIInfo - if v := valMap["ApiInfo"]; v != "" { - var arr []map[string]interface{} - if err := json.Unmarshal([]byte(v), &arr); err == nil { - if len(arr) > 50 { - arr = arr[:50] - } - bytes, _ := json.Marshal(arr) - model.UpdateOption("console_setting.api_info", string(bytes)) - } - model.UpdateOption("ApiInfo", "") - } - // Announcements 直接搬 - if v := valMap["Announcements"]; v != "" { - model.UpdateOption("console_setting.announcements", v) - model.UpdateOption("Announcements", "") - } - // FAQ 转换 - if v := valMap["FAQ"]; v != "" { - var arr []map[string]interface{} - if err := json.Unmarshal([]byte(v), &arr); err == nil { - out := []map[string]interface{}{} - for _, item := range arr { - q, _ := item["question"].(string) - if q == "" { - q, _ = item["title"].(string) - } - a, _ := item["answer"].(string) - if a == "" { - a, _ = item["content"].(string) - } - if q != "" && a != "" { - out = append(out, map[string]interface{}{"question": q, "answer": a}) - } - } - if len(out) > 50 { - out = out[:50] - } - bytes, _ := json.Marshal(out) - model.UpdateOption("console_setting.faq", string(bytes)) - } - model.UpdateOption("FAQ", "") - } - // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups) - url := valMap["UptimeKumaUrl"] - slug := valMap["UptimeKumaSlug"] - if url != "" && slug != "" { - // 仅当同时存在 URL 与 Slug 时才进行迁移 - groups := []map[string]interface{}{ - { - "id": 1, - "categoryName": "old", - "url": url, - "slug": slug, - "description": "", - }, - } - bytes, _ := json.Marshal(groups) - model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes)) - } - // 清空旧键内容 - if url != "" { - model.UpdateOption("UptimeKumaUrl", "") - } - if slug != "" { - model.UpdateOption("UptimeKumaSlug", "") - } + // 处理 APIInfo + if v := valMap["ApiInfo"]; v != "" { + var arr []map[string]interface{} + if err := json.Unmarshal([]byte(v), &arr); err == nil { + if len(arr) > 50 { + arr = arr[:50] + } + bytes, _ := json.Marshal(arr) + model.UpdateOption("console_setting.api_info", string(bytes)) + } + model.UpdateOption("ApiInfo", "") + } + // Announcements 直接搬 + if v := valMap["Announcements"]; v != "" { + model.UpdateOption("console_setting.announcements", v) + model.UpdateOption("Announcements", "") + } + // FAQ 转换 + if v := valMap["FAQ"]; v != "" { + var arr []map[string]interface{} + if err := json.Unmarshal([]byte(v), &arr); err == nil { + out := []map[string]interface{}{} + for _, item := range arr { + q, _ := item["question"].(string) + if q == "" { + q, _ = item["title"].(string) + } + a, _ := item["answer"].(string) + if a == "" { + a, _ = item["content"].(string) + } + if q != "" && a != "" { + out = append(out, map[string]interface{}{"question": q, "answer": a}) + } + } + if len(out) > 50 { + out = out[:50] + } + bytes, _ := json.Marshal(out) + model.UpdateOption("console_setting.faq", string(bytes)) + } + model.UpdateOption("FAQ", "") + } + // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups) + url := valMap["UptimeKumaUrl"] + slug := valMap["UptimeKumaSlug"] + if url != "" && slug != "" { + // 仅当同时存在 URL 与 Slug 时才进行迁移 + groups := []map[string]interface{}{ + { + "id": 1, + "categoryName": "old", + "url": url, + "slug": slug, + "description": "", + }, + } + bytes, _ := json.Marshal(groups) + model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes)) + } + // 清空旧键内容 + if url != "" { + model.UpdateOption("UptimeKumaUrl", "") + } + if slug != "" { + model.UpdateOption("UptimeKumaSlug", "") + } - // 删除旧键记录 - oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"} - model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{}) + // 删除旧键记录 + oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"} + model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{}) - // 重新加载 OptionMap - model.InitOptionMap() - common.SysLog("console setting migrated") - c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) -} \ No newline at end of file + // 重新加载 OptionMap + model.InitOptionMap() + logger.SysLog("console setting migrated") + c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) +} diff --git a/controller/github.go b/controller/github.go index 881d6dc1..0715a8fe 100644 --- a/controller/github.go +++ b/controller/github.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/logger" "one-api/model" "strconv" "time" @@ -47,7 +48,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } res, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res.Body.Close() @@ -63,7 +64,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) res2, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res2.Body.Close() diff --git a/controller/midjourney.go b/controller/midjourney.go index 30a5a09a..a67d39c2 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -9,6 +9,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -28,7 +29,7 @@ func UpdateMidjourneyTaskBulk() { continue } - common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) + logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) taskChannelM := make(map[int][]string) taskM := make(map[string]*model.Midjourney) nullTaskIds := make([]int, 0) @@ -47,9 +48,9 @@ func UpdateMidjourneyTaskBulk() { "progress": "100%", }) if err != nil { - common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) } else { - common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) + logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { @@ -57,20 +58,20 @@ func UpdateMidjourneyTaskBulk() { } for channelId, taskIds := range taskChannelM { - common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { continue } midjourneyChannel, err := model.CacheGetChannel(channelId) if err != nil { - common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) + logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) err := model.MjBulkUpdate(taskIds, map[string]any{ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), "status": "FAILURE", "progress": "100%", }) if err != nil { - common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) + logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) } continue } @@ -81,7 +82,7 @@ func UpdateMidjourneyTaskBulk() { }) req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) continue } // 设置超时时间 @@ -93,22 +94,22 @@ func UpdateMidjourneyTaskBulk() { req.Header.Set("mj-api-secret", midjourneyChannel.Key) resp, err := service.GetHttpClient().Do(req) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue } if resp.StatusCode != http.StatusOK { - common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) continue } responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) continue } var responseItems []dto.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) continue } resp.Body.Close() @@ -147,12 +148,12 @@ func UpdateMidjourneyTaskBulk() { } // 映射 VideoUrl task.VideoUrl = responseItem.VideoUrl - + // 映射 VideoUrls - 将数组序列化为 JSON 字符串 if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 { videoUrlsStr, err := json.Marshal(responseItem.VideoUrls) if err != nil { - common.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err)) + logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err)) task.VideoUrls = "[]" // 失败时设置为空数组 } else { task.VideoUrls = string(videoUrlsStr) @@ -160,10 +161,10 @@ func UpdateMidjourneyTaskBulk() { } else { task.VideoUrls = "" // 空值时清空字段 } - + shouldReturnQuota := false if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { - common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) + logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" if task.Quota != 0 { shouldReturnQuota = true @@ -171,14 +172,14 @@ func UpdateMidjourneyTaskBulk() { } err = task.Update() if err != nil { - common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) + logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) } else { if shouldReturnQuota { err = model.IncreaseUserQuota(task.UserId, task.Quota, false) if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota)) + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } diff --git a/controller/oidc.go b/controller/oidc.go index df8ea1c4..1e3435a8 100644 --- a/controller/oidc.go +++ b/controller/oidc.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/logger" "one-api/model" "one-api/setting" "one-api/setting/system_setting" @@ -58,7 +59,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } res, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") } defer res.Body.Close() @@ -69,7 +70,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } if oidcResponse.AccessToken == "" { - common.SysError("OIDC 获取 Token 失败,请检查设置!") + logger.SysError("OIDC 获取 Token 失败,请检查设置!") return nil, errors.New("OIDC 获取 Token 失败,请检查设置!") } @@ -80,12 +81,12 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) res2, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") } defer res2.Body.Close() if res2.StatusCode != http.StatusOK { - common.SysError("OIDC 获取用户信息失败!请检查设置!") + logger.SysError("OIDC 获取用户信息失败!请检查设置!") return nil, errors.New("OIDC 获取用户信息失败!请检查设置!") } @@ -95,7 +96,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { return nil, err } if oidcUser.OpenID == "" || oidcUser.Email == "" { - common.SysError("OIDC 获取用户信息为空!请检查设置!") + logger.SysError("OIDC 获取用户信息为空!请检查设置!") return nil, errors.New("OIDC 获取用户信息为空!请检查设置!") } return &oidcUser, nil diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index 0453870d..6fba0aac 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -1,474 +1,474 @@ package controller import ( - "context" - "encoding/json" - "fmt" - "net/http" - "strings" - "sync" - "time" + "context" + "encoding/json" + "fmt" + "net/http" + "one-api/logger" + "strings" + "sync" + "time" - "one-api/common" - "one-api/dto" - "one-api/model" - "one-api/setting/ratio_setting" + "one-api/dto" + "one-api/model" + "one-api/setting/ratio_setting" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) const ( - defaultTimeoutSeconds = 10 - defaultEndpoint = "/api/ratio_config" - maxConcurrentFetches = 8 + defaultTimeoutSeconds = 10 + defaultEndpoint = "/api/ratio_config" + maxConcurrentFetches = 8 ) var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} type upstreamResult struct { - Name string `json:"name"` - Data map[string]any `json:"data,omitempty"` - Err string `json:"err,omitempty"` + Name string `json:"name"` + Data map[string]any `json:"data,omitempty"` + Err string `json:"err,omitempty"` } func FetchUpstreamRatios(c *gin.Context) { - var req dto.UpstreamRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) - return - } + var req dto.UpstreamRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } - if req.Timeout <= 0 { - req.Timeout = defaultTimeoutSeconds - } + if req.Timeout <= 0 { + req.Timeout = defaultTimeoutSeconds + } - var upstreams []dto.UpstreamDTO + var upstreams []dto.UpstreamDTO - if len(req.Upstreams) > 0 { - for _, u := range req.Upstreams { - if strings.HasPrefix(u.BaseURL, "http") { - if u.Endpoint == "" { - u.Endpoint = defaultEndpoint - } - u.BaseURL = strings.TrimRight(u.BaseURL, "/") - upstreams = append(upstreams, u) - } - } - } else if len(req.ChannelIDs) > 0 { - intIds := make([]int, 0, len(req.ChannelIDs)) - for _, id64 := range req.ChannelIDs { - intIds = append(intIds, int(id64)) - } - dbChannels, err := model.GetChannelsByIds(intIds) - if err != nil { - common.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) - c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) - return - } - for _, ch := range dbChannels { - if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { - upstreams = append(upstreams, dto.UpstreamDTO{ - ID: ch.Id, - Name: ch.Name, - BaseURL: strings.TrimRight(base, "/"), - Endpoint: "", - }) - } - } - } + if len(req.Upstreams) > 0 { + for _, u := range req.Upstreams { + if strings.HasPrefix(u.BaseURL, "http") { + if u.Endpoint == "" { + u.Endpoint = defaultEndpoint + } + u.BaseURL = strings.TrimRight(u.BaseURL, "/") + upstreams = append(upstreams, u) + } + } + } else if len(req.ChannelIDs) > 0 { + intIds := make([]int, 0, len(req.ChannelIDs)) + for _, id64 := range req.ChannelIDs { + intIds = append(intIds, int(id64)) + } + dbChannels, err := model.GetChannelsByIds(intIds) + if err != nil { + logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) + return + } + for _, ch := range dbChannels { + if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { + upstreams = append(upstreams, dto.UpstreamDTO{ + ID: ch.Id, + Name: ch.Name, + BaseURL: strings.TrimRight(base, "/"), + Endpoint: "", + }) + } + } + } - if len(upstreams) == 0 { - c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) - return - } + if len(upstreams) == 0 { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) + return + } - var wg sync.WaitGroup - ch := make(chan upstreamResult, len(upstreams)) + var wg sync.WaitGroup + ch := make(chan upstreamResult, len(upstreams)) - sem := make(chan struct{}, maxConcurrentFetches) + sem := make(chan struct{}, maxConcurrentFetches) - client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} + client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} - for _, chn := range upstreams { - wg.Add(1) - go func(chItem dto.UpstreamDTO) { - defer wg.Done() + for _, chn := range upstreams { + wg.Add(1) + go func(chItem dto.UpstreamDTO) { + defer wg.Done() - sem <- struct{}{} - defer func() { <-sem }() + sem <- struct{}{} + defer func() { <-sem }() - endpoint := chItem.Endpoint - if endpoint == "" { - endpoint = defaultEndpoint - } else if !strings.HasPrefix(endpoint, "/") { - endpoint = "/" + endpoint - } - fullURL := chItem.BaseURL + endpoint + endpoint := chItem.Endpoint + if endpoint == "" { + endpoint = defaultEndpoint + } else if !strings.HasPrefix(endpoint, "/") { + endpoint = "/" + endpoint + } + fullURL := chItem.BaseURL + endpoint - uniqueName := chItem.Name - if chItem.ID != 0 { - uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) - } + uniqueName := chItem.Name + if chItem.ID != 0 { + uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) + } - ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) + defer cancel() - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) - if err != nil { - common.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } - resp, err := client.Do(httpReq) - if err != nil { - common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) - ch <- upstreamResult{Name: uniqueName, Err: resp.Status} - return - } - // 兼容两种上游接口格式: - // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price - // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 - var body struct { - Success bool `json:"success"` - Data json.RawMessage `json:"data"` - Message string `json:"message"` - } + resp, err := client.Do(httpReq) + if err != nil { + logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) + ch <- upstreamResult{Name: uniqueName, Err: resp.Status} + return + } + // 兼容两种上游接口格式: + // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price + // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 + var body struct { + Success bool `json:"success"` + Data json.RawMessage `json:"data"` + Message string `json:"message"` + } - if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { - common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } - if !body.Success { - ch <- upstreamResult{Name: uniqueName, Err: body.Message} - return - } + if !body.Success { + ch <- upstreamResult{Name: uniqueName, Err: body.Message} + return + } - // 尝试按 type1 解析 - var type1Data map[string]any - if err := json.Unmarshal(body.Data, &type1Data); err == nil { - // 如果包含至少一个 ratioTypes 字段,则认为是 type1 - isType1 := false - for _, rt := range ratioTypes { - if _, ok := type1Data[rt]; ok { - isType1 = true - break - } - } - if isType1 { - ch <- upstreamResult{Name: uniqueName, Data: type1Data} - return - } - } + // 尝试按 type1 解析 + var type1Data map[string]any + if err := json.Unmarshal(body.Data, &type1Data); err == nil { + // 如果包含至少一个 ratioTypes 字段,则认为是 type1 + isType1 := false + for _, rt := range ratioTypes { + if _, ok := type1Data[rt]; ok { + isType1 = true + break + } + } + if isType1 { + ch <- upstreamResult{Name: uniqueName, Data: type1Data} + return + } + } - // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析 - var pricingItems []struct { - ModelName string `json:"model_name"` - QuotaType int `json:"quota_type"` - ModelRatio float64 `json:"model_ratio"` - ModelPrice float64 `json:"model_price"` - CompletionRatio float64 `json:"completion_ratio"` - } - if err := json.Unmarshal(body.Data, &pricingItems); err != nil { - common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} - return - } + // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析 + var pricingItems []struct { + ModelName string `json:"model_name"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + CompletionRatio float64 `json:"completion_ratio"` + } + if err := json.Unmarshal(body.Data, &pricingItems); err != nil { + logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} + return + } - modelRatioMap := make(map[string]float64) - completionRatioMap := make(map[string]float64) - modelPriceMap := make(map[string]float64) + modelRatioMap := make(map[string]float64) + completionRatioMap := make(map[string]float64) + modelPriceMap := make(map[string]float64) - for _, item := range pricingItems { - if item.QuotaType == 1 { - modelPriceMap[item.ModelName] = item.ModelPrice - } else { - modelRatioMap[item.ModelName] = item.ModelRatio - // completionRatio 可能为 0,此时也直接赋值,保持与上游一致 - completionRatioMap[item.ModelName] = item.CompletionRatio - } - } + for _, item := range pricingItems { + if item.QuotaType == 1 { + modelPriceMap[item.ModelName] = item.ModelPrice + } else { + modelRatioMap[item.ModelName] = item.ModelRatio + // completionRatio 可能为 0,此时也直接赋值,保持与上游一致 + completionRatioMap[item.ModelName] = item.CompletionRatio + } + } - converted := make(map[string]any) + converted := make(map[string]any) - if len(modelRatioMap) > 0 { - ratioAny := make(map[string]any, len(modelRatioMap)) - for k, v := range modelRatioMap { - ratioAny[k] = v - } - converted["model_ratio"] = ratioAny - } + if len(modelRatioMap) > 0 { + ratioAny := make(map[string]any, len(modelRatioMap)) + for k, v := range modelRatioMap { + ratioAny[k] = v + } + converted["model_ratio"] = ratioAny + } - if len(completionRatioMap) > 0 { - compAny := make(map[string]any, len(completionRatioMap)) - for k, v := range completionRatioMap { - compAny[k] = v - } - converted["completion_ratio"] = compAny - } + if len(completionRatioMap) > 0 { + compAny := make(map[string]any, len(completionRatioMap)) + for k, v := range completionRatioMap { + compAny[k] = v + } + converted["completion_ratio"] = compAny + } - if len(modelPriceMap) > 0 { - priceAny := make(map[string]any, len(modelPriceMap)) - for k, v := range modelPriceMap { - priceAny[k] = v - } - converted["model_price"] = priceAny - } + if len(modelPriceMap) > 0 { + priceAny := make(map[string]any, len(modelPriceMap)) + for k, v := range modelPriceMap { + priceAny[k] = v + } + converted["model_price"] = priceAny + } - ch <- upstreamResult{Name: uniqueName, Data: converted} - }(chn) - } + ch <- upstreamResult{Name: uniqueName, Data: converted} + }(chn) + } - wg.Wait() - close(ch) + wg.Wait() + close(ch) - localData := ratio_setting.GetExposedData() + localData := ratio_setting.GetExposedData() - var testResults []dto.TestResult - var successfulChannels []struct { - name string - data map[string]any - } + var testResults []dto.TestResult + var successfulChannels []struct { + name string + data map[string]any + } - for r := range ch { - if r.Err != "" { - testResults = append(testResults, dto.TestResult{ - Name: r.Name, - Status: "error", - Error: r.Err, - }) - } else { - testResults = append(testResults, dto.TestResult{ - Name: r.Name, - Status: "success", - }) - successfulChannels = append(successfulChannels, struct { - name string - data map[string]any - }{name: r.Name, data: r.Data}) - } - } + for r := range ch { + if r.Err != "" { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "error", + Error: r.Err, + }) + } else { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "success", + }) + successfulChannels = append(successfulChannels, struct { + name string + data map[string]any + }{name: r.Name, data: r.Data}) + } + } - differences := buildDifferences(localData, successfulChannels) + differences := buildDifferences(localData, successfulChannels) - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": gin.H{ - "differences": differences, - "test_results": testResults, - }, - }) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "differences": differences, + "test_results": testResults, + }, + }) } func buildDifferences(localData map[string]any, successfulChannels []struct { - name string - data map[string]any + name string + data map[string]any }) map[string]map[string]dto.DifferenceItem { - differences := make(map[string]map[string]dto.DifferenceItem) + differences := make(map[string]map[string]dto.DifferenceItem) - allModels := make(map[string]struct{}) - - for _, ratioType := range ratioTypes { - if localRatioAny, ok := localData[ratioType]; ok { - if localRatio, ok := localRatioAny.(map[string]float64); ok { - for modelName := range localRatio { - allModels[modelName] = struct{}{} - } - } - } - } - - for _, channel := range successfulChannels { - for _, ratioType := range ratioTypes { - if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { - for modelName := range upstreamRatio { - allModels[modelName] = struct{}{} - } - } - } - } + allModels := make(map[string]struct{}) - confidenceMap := make(map[string]map[string]bool) - - // 预处理阶段:检查pricing接口的可信度 - for _, channel := range successfulChannels { - confidenceMap[channel.name] = make(map[string]bool) - - modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) - completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) - - if hasModelRatio && hasCompletionRatio { - // 遍历所有模型,检查是否满足不可信条件 - for modelName := range allModels { - // 默认为可信 - confidenceMap[channel.name][modelName] = true - - // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1 - if modelRatioVal, ok := modelRatios[modelName]; ok { - if completionRatioVal, ok := completionRatios[modelName]; ok { - // 转换为float64进行比较 - if modelRatioFloat, ok := modelRatioVal.(float64); ok { - if completionRatioFloat, ok := completionRatioVal.(float64); ok { - if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { - confidenceMap[channel.name][modelName] = false - } - } - } - } - } - } - } else { - // 如果不是从pricing接口获取的数据,则全部标记为可信 - for modelName := range allModels { - confidenceMap[channel.name][modelName] = true - } - } - } + for _, ratioType := range ratioTypes { + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + for modelName := range localRatio { + allModels[modelName] = struct{}{} + } + } + } + } - for modelName := range allModels { - for _, ratioType := range ratioTypes { - var localValue interface{} = nil - if localRatioAny, ok := localData[ratioType]; ok { - if localRatio, ok := localRatioAny.(map[string]float64); ok { - if val, exists := localRatio[modelName]; exists { - localValue = val - } - } - } + for _, channel := range successfulChannels { + for _, ratioType := range ratioTypes { + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + for modelName := range upstreamRatio { + allModels[modelName] = struct{}{} + } + } + } + } - upstreamValues := make(map[string]interface{}) - confidenceValues := make(map[string]bool) - hasUpstreamValue := false - hasDifference := false + confidenceMap := make(map[string]map[string]bool) - for _, channel := range successfulChannels { - var upstreamValue interface{} = nil - - if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { - if val, exists := upstreamRatio[modelName]; exists { - upstreamValue = val - hasUpstreamValue = true - - if localValue != nil && localValue != val { - hasDifference = true - } else if localValue == val { - upstreamValue = "same" - } - } - } - if upstreamValue == nil && localValue == nil { - upstreamValue = "same" - } - - if localValue == nil && upstreamValue != nil && upstreamValue != "same" { - hasDifference = true - } - - upstreamValues[channel.name] = upstreamValue - - confidenceValues[channel.name] = confidenceMap[channel.name][modelName] - } + // 预处理阶段:检查pricing接口的可信度 + for _, channel := range successfulChannels { + confidenceMap[channel.name] = make(map[string]bool) - shouldInclude := false - - if localValue != nil { - if hasDifference { - shouldInclude = true - } - } else { - if hasUpstreamValue { - shouldInclude = true - } - } + modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) + completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) - if shouldInclude { - if differences[modelName] == nil { - differences[modelName] = make(map[string]dto.DifferenceItem) - } - differences[modelName][ratioType] = dto.DifferenceItem{ - Current: localValue, - Upstreams: upstreamValues, - Confidence: confidenceValues, - } - } - } - } + if hasModelRatio && hasCompletionRatio { + // 遍历所有模型,检查是否满足不可信条件 + for modelName := range allModels { + // 默认为可信 + confidenceMap[channel.name][modelName] = true - channelHasDiff := make(map[string]bool) - for _, ratioMap := range differences { - for _, item := range ratioMap { - for chName, val := range item.Upstreams { - if val != nil && val != "same" { - channelHasDiff[chName] = true - } - } - } - } + // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1 + if modelRatioVal, ok := modelRatios[modelName]; ok { + if completionRatioVal, ok := completionRatios[modelName]; ok { + // 转换为float64进行比较 + if modelRatioFloat, ok := modelRatioVal.(float64); ok { + if completionRatioFloat, ok := completionRatioVal.(float64); ok { + if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { + confidenceMap[channel.name][modelName] = false + } + } + } + } + } + } + } else { + // 如果不是从pricing接口获取的数据,则全部标记为可信 + for modelName := range allModels { + confidenceMap[channel.name][modelName] = true + } + } + } - for modelName, ratioMap := range differences { - for ratioType, item := range ratioMap { - for chName := range item.Upstreams { - if !channelHasDiff[chName] { - delete(item.Upstreams, chName) - delete(item.Confidence, chName) - } - } + for modelName := range allModels { + for _, ratioType := range ratioTypes { + var localValue interface{} = nil + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + if val, exists := localRatio[modelName]; exists { + localValue = val + } + } + } - allSame := true - for _, v := range item.Upstreams { - if v != "same" { - allSame = false - break - } - } - if len(item.Upstreams) == 0 || allSame { - delete(ratioMap, ratioType) - } else { - differences[modelName][ratioType] = item - } - } + upstreamValues := make(map[string]interface{}) + confidenceValues := make(map[string]bool) + hasUpstreamValue := false + hasDifference := false - if len(ratioMap) == 0 { - delete(differences, modelName) - } - } + for _, channel := range successfulChannels { + var upstreamValue interface{} = nil - return differences + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + if val, exists := upstreamRatio[modelName]; exists { + upstreamValue = val + hasUpstreamValue = true + + if localValue != nil && localValue != val { + hasDifference = true + } else if localValue == val { + upstreamValue = "same" + } + } + } + if upstreamValue == nil && localValue == nil { + upstreamValue = "same" + } + + if localValue == nil && upstreamValue != nil && upstreamValue != "same" { + hasDifference = true + } + + upstreamValues[channel.name] = upstreamValue + + confidenceValues[channel.name] = confidenceMap[channel.name][modelName] + } + + shouldInclude := false + + if localValue != nil { + if hasDifference { + shouldInclude = true + } + } else { + if hasUpstreamValue { + shouldInclude = true + } + } + + if shouldInclude { + if differences[modelName] == nil { + differences[modelName] = make(map[string]dto.DifferenceItem) + } + differences[modelName][ratioType] = dto.DifferenceItem{ + Current: localValue, + Upstreams: upstreamValues, + Confidence: confidenceValues, + } + } + } + } + + channelHasDiff := make(map[string]bool) + for _, ratioMap := range differences { + for _, item := range ratioMap { + for chName, val := range item.Upstreams { + if val != nil && val != "same" { + channelHasDiff[chName] = true + } + } + } + } + + for modelName, ratioMap := range differences { + for ratioType, item := range ratioMap { + for chName := range item.Upstreams { + if !channelHasDiff[chName] { + delete(item.Upstreams, chName) + delete(item.Confidence, chName) + } + } + + allSame := true + for _, v := range item.Upstreams { + if v != "same" { + allSame = false + break + } + } + if len(item.Upstreams) == 0 || allSame { + delete(ratioMap, ratioType) + } else { + differences[modelName][ratioType] = item + } + } + + if len(ratioMap) == 0 { + delete(differences, modelName) + } + } + + return differences } func GetSyncableChannels(c *gin.Context) { - channels, err := model.GetAllChannels(0, 0, true, false) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } + channels, err := model.GetAllChannels(0, 0, true, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } - var syncableChannels []dto.SyncableChannel - for _, channel := range channels { - if channel.GetBaseURL() != "" { - syncableChannels = append(syncableChannels, dto.SyncableChannel{ - ID: channel.Id, - Name: channel.Name, - BaseURL: channel.GetBaseURL(), - Status: channel.Status, - }) - } - } + var syncableChannels []dto.SyncableChannel + for _, channel := range channels { + if channel.GetBaseURL() != "" { + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: channel.Id, + Name: channel.Name, + BaseURL: channel.GetBaseURL(), + Status: channel.Status, + }) + } + } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": syncableChannels, - }) -} \ No newline at end of file + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": syncableChannels, + }) +} diff --git a/controller/relay.go b/controller/relay.go index d235f550..583ac036 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,21 +2,22 @@ package controller import ( "bytes" - "errors" "fmt" "io" "log" "net/http" "one-api/common" "one-api/constant" - constant2 "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/relay" + relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/setting" "one-api/types" "strings" @@ -24,81 +25,196 @@ import ( "github.com/gorilla/websocket" ) -func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { +func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError { var err *types.NewAPIError - switch relayMode { + switch info.RelayMode { case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: - err = relay.ImageHelper(c) + err = relay.ImageHelper(c, info) case relayconstant.RelayModeAudioSpeech: fallthrough case relayconstant.RelayModeAudioTranslation: fallthrough case relayconstant.RelayModeAudioTranscription: - err = relay.AudioHelper(c) + err = relay.AudioHelper(c, info) case relayconstant.RelayModeRerank: - err = relay.RerankHelper(c, relayMode) + err = relay.RerankHelper(c, info) case relayconstant.RelayModeEmbeddings: - err = relay.EmbeddingHelper(c) + err = relay.EmbeddingHelper(c, info) case relayconstant.RelayModeResponses: - err = relay.ResponsesHelper(c) - case relayconstant.RelayModeGemini: - if strings.Contains(c.Request.URL.Path, "embed") { - err = relay.GeminiEmbeddingHandler(c) - } else { - err = relay.GeminiHelper(c) - } + err = relay.ResponsesHelper(c, info) default: - err = relay.TextHelper(c) + err = relay.TextHelper(c, info) } - - if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) { - // 保存错误日志到mysql中 - userId := c.GetInt("id") - tokenName := c.GetString("token_name") - modelName := c.GetString("original_model") - tokenId := c.GetInt("token_id") - userGroup := c.GetString("group") - channelId := c.GetInt("channel_id") - other := make(map[string]interface{}) - other["error_type"] = err.GetErrorType() - other["error_code"] = err.GetErrorCode() - other["status_code"] = err.StatusCode - other["channel_id"] = channelId - other["channel_name"] = c.GetString("channel_name") - other["channel_type"] = c.GetInt("channel_type") - adminInfo := make(map[string]interface{}) - adminInfo["use_channel"] = c.GetStringSlice("use_channel") - isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey) - if isMultiKey { - adminInfo["is_multi_key"] = true - adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex) - } - other["admin_info"] = adminInfo - model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other) - } - return err } -func Relay(c *gin.Context) { - relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) +func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError { + var err *types.NewAPIError + if strings.Contains(c.Request.URL.Path, "embed") { + err = relay.GeminiEmbeddingHandler(c, info) + } else { + err = relay.GeminiHelper(c, info) + } + return err +} + +func Relay(c *gin.Context, relayFormat types.RelayFormat) { + requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError + + var ( + newAPIError *types.NewAPIError + ws *websocket.Conn + ) + + if relayFormat == types.RelayFormatOpenAIRealtime { + var err error + ws, err = upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) + return + } + defer ws.Close() + } + + defer func() { + if newAPIError != nil { + newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) + switch relayFormat { + case types.RelayFormatOpenAIRealtime: + helper.WssError(c, ws, newAPIError.ToOpenAIError()) + case types.RelayFormatClaude: + c.JSON(newAPIError.StatusCode, gin.H{ + "type": "error", + "error": newAPIError.ToClaudeError(), + }) + default: + c.JSON(newAPIError.StatusCode, gin.H{ + "error": newAPIError.ToOpenAIError(), + }) + } + } + }() + + request, err := helper.GetAndValidateRequest(c, relayFormat) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) + return + } + + //includeUsage := true + //// 判断用户是否需要返回使用情况 + //if textRequest.StreamOptions != nil { + // includeUsage = textRequest.StreamOptions.IncludeUsage + //} + // + //// 如果不支持StreamOptions,将StreamOptions设置为nil + //if !relayInfo.SupportStreamOptions || !textRequest.Stream { + // textRequest.StreamOptions = nil + //} else { + // // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions + // if constant.ForceStreamOption { + // textRequest.StreamOptions = &dto.StreamOptions{ + // IncludeUsage: true, + // } + // } + //} + // + //relayInfo.ShouldIncludeUsage = includeUsage + + relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed) + return + } + + meta := request.GetTokenCountMeta() + + if setting.ShouldCheckPromptSensitive() { + words, err := service.CheckSensitiveText(meta.CombineText) + if err != nil { + logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) + newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return + } + } + + tokens, err := service.CountRequestToken(c, meta, relayInfo) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed) + return + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeModelPriceError) + return + } + + preConsumedQuota, newApiErr := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newApiErr != nil { + return + } + + defer func() { + if newApiErr != nil { + service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) + } + }() for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) newAPIError = err break } - newAPIError = relayRequest(c, relayMode, channel) + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + switch relayFormat { + case types.RelayFormatOpenAIRealtime: + newAPIError = relay.WssHelper(c, ws) + case types.RelayFormatClaude: + newAPIError = relay.ClaudeHelper(c, relayInfo) + case types.RelayFormatGemini: + newAPIError = geminiRelayHandler(c, relayInfo) + default: + newAPIError = relayHandler(c, relayInfo) + } if newAPIError == nil { - return // 成功处理请求,直接返回 + return + } else { + if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) { + // 保存错误日志到mysql中 + userId := c.GetInt("id") + tokenName := c.GetString("token_name") + modelName := c.GetString("original_model") + tokenId := c.GetInt("token_id") + userGroup := c.GetString("group") + channelId := c.GetInt("channel_id") + other := make(map[string]interface{}) + other["error_type"] = newAPIError.GetErrorType() + other["error_code"] = newAPIError.GetErrorCode() + other["status_code"] = newAPIError.StatusCode + other["channel_id"] = channelId + other["channel_name"] = c.GetString("channel_name") + other["channel_type"] = c.GetInt("channel_type") + adminInfo := make(map[string]interface{}) + adminInfo["use_channel"] = c.GetStringSlice("use_channel") + isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey) + if isMultiKey { + adminInfo["is_multi_key"] = true + adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex) + } + other["admin_info"] = adminInfo + model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other) + } } go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) @@ -107,21 +223,11 @@ func Relay(c *gin.Context) { break } } + useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - //if newAPIError.StatusCode == http.StatusTooManyRequests { - // common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error())) - // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - //} - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - c.JSON(newAPIError.StatusCode, gin.H{ - "error": newAPIError.ToOpenAIError(), - }) + logger.LogInfo(c, retryLogStr) } } @@ -132,122 +238,6 @@ var upgrader = websocket.Upgrader{ }, } -func WssRelay(c *gin.Context) { - // 将 HTTP 连接升级为 WebSocket 连接 - - ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) - defer ws.Close() - - if err != nil { - helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) - return - } - - relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) - requestId := c.GetString(common.RequestIdKey) - group := c.GetString("group") - //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 - originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError - - for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, err.Error()) - newAPIError = err - break - } - - newAPIError = wssRequest(c, ws, relayMode, channel) - - if newAPIError == nil { - return // 成功处理请求,直接返回 - } - - go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - - if !shouldRetry(c, newAPIError, common.RetryTimes-i) { - break - } - } - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - //if newAPIError.StatusCode == http.StatusTooManyRequests { - // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - //} - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - helper.WssError(c, ws, newAPIError.ToOpenAIError()) - } -} - -func RelayClaude(c *gin.Context) { - //relayMode := constant.Path2RelayMode(c.Request.URL.Path) - requestId := c.GetString(common.RequestIdKey) - group := c.GetString("group") - originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError - - for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, err.Error()) - newAPIError = err - break - } - - newAPIError = claudeRequest(c, channel) - - if newAPIError == nil { - return // 成功处理请求,直接返回 - } - - go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - - if !shouldRetry(c, newAPIError, common.RetryTimes-i) { - break - } - } - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - c.JSON(newAPIError.StatusCode, gin.H{ - "type": "error", - "error": newAPIError.ToClaudeError(), - }) - } -} - -func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relayHandler(c, relayMode) -} - -func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relay.WssHelper(c, ws) -} - -func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relay.ClaudeHelper(c) -} - func addUsedChannel(c *gin.Context, channelId int) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) @@ -270,10 +260,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m } channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { - return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) + return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } if channel == nil { - return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) + return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) if newAPIError != nil { @@ -327,7 +317,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { service.DisableChannel(channelError, err.Error()) } @@ -362,7 +352,7 @@ func RelayMidjourney(c *gin.Context) { "code": err.Code, }) channelId := c.GetInt("channel_id") - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) } } @@ -404,7 +394,7 @@ func RelayTask(c *gin.Context) { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { channel, newAPIError := getChannel(c, group, originalModel, i) if newAPIError != nil { - common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) + logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) break } @@ -412,7 +402,7 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) - common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) //middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, _ := common.GetRequestBody(c) @@ -422,7 +412,7 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) + logger.LogInfo(c, retryLogStr) } if taskErr != nil { if taskErr.StatusCode == http.StatusTooManyRequests { diff --git a/controller/task.go b/controller/task.go index 5fbdb424..a5b28ae2 100644 --- a/controller/task.go +++ b/controller/task.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/relay" "sort" @@ -25,7 +26,7 @@ func UpdateTaskBulk() { //imageModel := "midjourney" for { time.Sleep(time.Duration(15) * time.Second) - common.SysLog("任务进度轮询开始") + logger.SysLog("任务进度轮询开始") ctx := context.TODO() allTasks := model.GetAllUnFinishSyncTasks(500) platformTask := make(map[constant.TaskPlatform][]*model.Task) @@ -54,9 +55,9 @@ func UpdateTaskBulk() { "progress": "100%", }) if err != nil { - common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) } else { - common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { @@ -65,7 +66,7 @@ func UpdateTaskBulk() { UpdateTaskByPlatform(platform, taskChannelM, taskM) } - common.SysLog("任务进度轮询完成") + logger.SysLog("任务进度轮询完成") } } @@ -77,7 +78,7 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) default: if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) + logger.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) } } } @@ -86,27 +87,27 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM for channelId, taskIds := range taskChannelM { err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) if err != nil { - common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) } } return nil } func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { - common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { return nil } channel, err := model.CacheGetChannel(channelId) if err != nil { - common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + logger.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) err = model.TaskBulkUpdate(taskIds, map[string]any{ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), "status": "FAILURE", "progress": "100%", }) if err != nil { - common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) + logger.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) } return err } @@ -118,27 +119,27 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas "ids": taskIds, }) if err != nil { - common.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) + logger.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) return err } if resp.StatusCode != http.StatusOK { - common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) } defer resp.Body.Close() responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) + logger.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) return err } var responseItems dto.TaskResponse[[]dto.SunoDataResponse] err = json.Unmarshal(responseBody, &responseItems) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) return err } if !responseItems.IsSuccess() { - common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) + logger.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) return err } @@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { - common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) task.Progress = "100%" //err = model.CacheUpdateUserQuota(task.UserId) ? if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) + logger.LogError(ctx, "error update user quota cache: "+err.Error()) } else { quota := task.Quota if quota != 0 { err = model.IncreaseUserQuota(task.UserId, quota, false) if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota)) + logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } @@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas err = task.Update() if err != nil { - common.SysError("UpdateMidjourneyTask task error: " + err.Error()) + logger.SysError("UpdateMidjourneyTask task error: " + err.Error()) } } return nil diff --git a/controller/task_video.go b/controller/task_video.go index 914bf6e6..dca42955 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -5,9 +5,9 @@ import ( "encoding/json" "fmt" "io" - "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/relay" "one-api/relay/channel" @@ -18,14 +18,14 @@ import ( func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { for channelId, taskIds := range taskChannelM { if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { - common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) } } return nil } func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { - common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) if len(taskIds) == 0 { return nil } @@ -37,7 +37,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha "progress": "100%", }) if errUpdate != nil { - common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + logger.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) } return fmt.Errorf("CacheGetChannel failed: %w", err) } @@ -47,7 +47,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha } for _, taskId := range taskIds { if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { - common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) } } return nil @@ -61,7 +61,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task := taskM[taskId] if task == nil { - common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) return fmt.Errorf("task %s not found", taskId) } resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ @@ -124,13 +124,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.FinishTime = now } task.FailReason = taskResult.Reason - common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) quota := task.Quota if quota != 0 { if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { - common.LogError(ctx, "Failed to increase user quota: "+err.Error()) + logger.LogError(ctx, "Failed to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) + logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } default: @@ -140,7 +140,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.Progress = taskResult.Progress } if err := task.Update(); err != nil { - common.SysError("UpdateVideoTask task error: " + err.Error()) + logger.SysError("UpdateVideoTask task error: " + err.Error()) } return nil diff --git a/controller/token.go b/controller/token.go index 62eb5474..db575fec 100644 --- a/controller/token.go +++ b/controller/token.go @@ -3,6 +3,7 @@ package controller import ( "net/http" "one-api/common" + "one-api/logger" "one-api/model" "strconv" @@ -102,7 +103,7 @@ func AddToken(c *gin.Context) { "success": false, "message": "生成令牌失败", }) - common.SysError("failed to generate token key: " + err.Error()) + logger.SysError("failed to generate token key: " + err.Error()) return } cleanToken := model.Token{ diff --git a/controller/topup.go b/controller/topup.go index 827dda39..3f3c8623 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -5,6 +5,7 @@ import ( "log" "net/url" "one-api/common" + "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -231,7 +232,7 @@ func EpayNotify(c *gin.Context) { return } log.Printf("易支付回调更新用户成功 %v", topUp) - model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money)) + model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money)) } } else { log.Printf("易支付异常回调: %v", verifyInfo) diff --git a/controller/twofa.go b/controller/twofa.go index 9f48eed8..0ab66029 100644 --- a/controller/twofa.go +++ b/controller/twofa.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/logger" "one-api/model" "strconv" @@ -70,7 +71,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成2FA密钥失败", }) - common.SysError("生成TOTP密钥失败: " + err.Error()) + logger.SysError("生成TOTP密钥失败: " + err.Error()) return } @@ -81,7 +82,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - common.SysError("生成备用码失败: " + err.Error()) + logger.SysError("生成备用码失败: " + err.Error()) return } @@ -115,7 +116,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - common.SysError("保存备用码失败: " + err.Error()) + logger.SysError("保存备用码失败: " + err.Error()) return } @@ -294,7 +295,7 @@ func Get2FAStatus(c *gin.Context) { // 获取剩余备用码数量 backupCount, err := model.GetUnusedBackupCodeCount(userId) if err != nil { - common.SysError("获取备用码数量失败: " + err.Error()) + logger.SysError("获取备用码数量失败: " + err.Error()) } else { status["backup_codes_remaining"] = backupCount } @@ -368,7 +369,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - common.SysError("生成备用码失败: " + err.Error()) + logger.SysError("生成备用码失败: " + err.Error()) return } @@ -378,7 +379,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - common.SysError("保存备用码失败: " + err.Error()) + logger.SysError("保存备用码失败: " + err.Error()) return } diff --git a/controller/user.go b/controller/user.go index 29cf83e1..8ce44fa6 100644 --- a/controller/user.go +++ b/controller/user.go @@ -7,6 +7,7 @@ import ( "net/url" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/setting" "strconv" @@ -192,7 +193,7 @@ func Register(c *gin.Context) { "success": false, "message": "数据库错误,请稍后重试", }) - common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) + logger.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) return } if exist { @@ -235,7 +236,7 @@ func Register(c *gin.Context) { "success": false, "message": "生成默认令牌失败", }) - common.SysError("failed to generate token key: " + err.Error()) + logger.SysError("failed to generate token key: " + err.Error()) return } // 生成默认令牌 @@ -342,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) { "success": false, "message": "生成失败", }) - common.SysError("failed to generate key: " + err.Error()) + logger.SysError("failed to generate key: " + err.Error()) return } user.SetAccessToken(key) @@ -517,7 +518,7 @@ func UpdateUser(c *gin.Context) { return } if originUser.Quota != updatedUser.Quota { - model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) + model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota))) } c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/dto/audio.go b/dto/audio.go index c36b3da5..81872c69 100644 --- a/dto/audio.go +++ b/dto/audio.go @@ -1,5 +1,11 @@ package dto +import ( + "one-api/types" + + "github.com/gin-gonic/gin" +) + type AudioRequest struct { Model string `json:"model"` Input string `json:"input"` @@ -8,6 +14,18 @@ type AudioRequest struct { ResponseFormat string `json:"response_format,omitempty"` } +func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta { + meta := &types.TokenCountMeta{ + CombineText: r.Input, + TokenType: types.TokenTypeTextNumber, + } + return meta +} + +func (r *AudioRequest) IsStream(c *gin.Context) bool { + return false +} + type AudioResponse struct { Text string `json:"text"` } diff --git a/dto/claude.go b/dto/claude.go index 58a09217..2b3adf19 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -5,6 +5,9 @@ import ( "fmt" "one-api/common" "one-api/types" + "strings" + + "github.com/gin-gonic/gin" ) type ClaudeMetadata struct { @@ -81,7 +84,7 @@ func (c *ClaudeMediaMessage) GetStringContent() string { } func (c *ClaudeMediaMessage) GetJsonRowString() string { - jsonContent, _ := json.Marshal(c) + jsonContent, _ := common.Marshal(c) return string(jsonContent) } @@ -199,6 +202,129 @@ type ClaudeRequest struct { Thinking *Thinking `json:"thinking,omitempty"` } +func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { + var tokenCountMeta = types.TokenCountMeta{ + TokenType: types.TokenTypeTextNumber, + MaxTokens: int(c.MaxTokens), + } + + var texts = make([]string, 0) + var fileMeta = make([]*types.FileMeta, 0) + + // system + if c.System != nil { + if c.IsStringSystem() { + sys := c.GetStringSystem() + if sys != "" { + texts = append(texts, sys) + } + } else { + systemMedia := c.ParseSystem() + for _, media := range systemMedia { + switch media.Type { + case "text": + texts = append(texts, media.GetText()) + case "image": + if media.Source != nil { + data := media.Source.Url + if data == "" { + data = common.Interface2String(media.Source.Data) + } + if data != "" { + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) + } + } + } + } + } + } + + // messages + for _, message := range c.Messages { + tokenCountMeta.MessagesCount++ + texts = append(texts, message.Role) + if message.IsStringContent() { + content := message.GetStringContent() + if content != "" { + texts = append(texts, content) + } + continue + } + + content, _ := message.ParseContent() + for _, media := range content { + switch media.Type { + case "text": + texts = append(texts, media.GetText()) + case "image": + if media.Source != nil { + data := media.Source.Url + if data == "" { + data = common.Interface2String(media.Source.Data) + } + if data != "" { + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) + } + } + case "tool_use": + if media.Name != "" { + texts = append(texts, media.Name) + } + if media.Input != nil { + b, _ := common.Marshal(media.Input) + texts = append(texts, string(b)) + } + case "tool_result": + if media.Content != nil { + b, _ := common.Marshal(media.Content) + texts = append(texts, string(b)) + } + } + } + } + + // tools + if c.Tools != nil { + tools := c.GetTools() + normalTools, webSearchTools := ProcessTools(tools) + if normalTools != nil { + for _, t := range normalTools { + tokenCountMeta.ToolsCount++ + if t.Name != "" { + texts = append(texts, t.Name) + } + if t.Description != "" { + texts = append(texts, t.Description) + } + if t.InputSchema != nil { + b, _ := common.Marshal(t.InputSchema) + texts = append(texts, string(b)) + } + } + } + if webSearchTools != nil { + for _, t := range webSearchTools { + tokenCountMeta.ToolsCount++ + if t.Name != "" { + texts = append(texts, t.Name) + } + if t.UserLocation != nil { + b, _ := common.Marshal(t.UserLocation) + texts = append(texts, string(b)) + } + } + } + } + + tokenCountMeta.CombineText = strings.Join(texts, "\n") + tokenCountMeta.Files = fileMeta + return &tokenCountMeta +} + +func (claudeRequest *ClaudeRequest) IsStream(c *gin.Context) bool { + return claudeRequest.Stream +} + func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string { for _, message := range c.Messages { content, _ := message.ParseContent() diff --git a/dto/embedding.go b/dto/embedding.go index 9d722292..fff37776 100644 --- a/dto/embedding.go +++ b/dto/embedding.go @@ -1,5 +1,12 @@ package dto +import ( + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + type EmbeddingOptions struct { Seed int `json:"seed,omitempty"` Temperature *float64 `json:"temperature,omitempty"` @@ -24,9 +31,26 @@ type EmbeddingRequest struct { PresencePenalty float64 `json:"presence_penalty,omitempty"` } -func (r EmbeddingRequest) ParseInput() []string { +func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta { + var texts = make([]string, 0) + + inputs := r.ParseInput() + for _, input := range inputs { + texts = append(texts, input) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + } +} + +func (r *EmbeddingRequest) IsStream(c *gin.Context) bool { + return false +} + +func (r *EmbeddingRequest) ParseInput() []string { if r.Input == nil { - return nil + return make([]string, 0) } var input []string switch r.Input.(type) { diff --git a/dto/gemini.go b/dto/gemini.go index 6cb3e17a..b327de62 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -2,7 +2,10 @@ package dto import ( "encoding/json" + "github.com/gin-gonic/gin" "one-api/common" + "one-api/logger" + "one-api/types" "strings" ) @@ -14,19 +17,75 @@ type GeminiChatRequest struct { SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` } +func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { + var files []*types.FileMeta = make([]*types.FileMeta, 0) + + var maxTokens int + + if r.GenerationConfig.MaxOutputTokens > 0 { + maxTokens = int(r.GenerationConfig.MaxOutputTokens) + } + + var inputTexts []string + for _, content := range r.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + if part.InlineData != nil && part.InlineData.Data != "" { + if strings.HasPrefix(part.InlineData.MimeType, "image/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeImage, + Data: part.InlineData.Data, + }) + } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeAudio, + Data: part.InlineData.Data, + }) + } else if strings.HasPrefix(part.InlineData.MimeType, "video/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeVideo, + Data: part.InlineData.Data, + }) + } else { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeFile, + Data: part.InlineData.Data, + }) + } + } + } + } + + inputText := strings.Join(inputTexts, "\n") + return &types.TokenCountMeta{ + CombineText: inputText, + Files: files, + MaxTokens: maxTokens, + } +} + +func (r *GeminiChatRequest) IsStream(c *gin.Context) bool { + if c.Query("alt") == "sse" { + return true + } + return false +} + func (r *GeminiChatRequest) GetTools() []GeminiChatTool { var tools []GeminiChatTool if strings.HasSuffix(string(r.Tools), "[") { // is array if err := common.Unmarshal(r.Tools, &tools); err != nil { - common.LogError(nil, "error_unmarshalling_tools: "+err.Error()) + logger.LogError(nil, "error_unmarshalling_tools: "+err.Error()) return nil } } else if strings.HasPrefix(string(r.Tools), "{") { // is object singleTool := GeminiChatTool{} if err := common.Unmarshal(r.Tools, &singleTool); err != nil { - common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error()) + logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error()) return nil } tools = []GeminiChatTool{singleTool} @@ -43,7 +102,7 @@ func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) { // Marshal the tools to JSON data, err := common.Marshal(tools) if err != nil { - common.LogError(nil, "error_marshalling_tools: "+err.Error()) + logger.LogError(nil, "error_marshalling_tools: "+err.Error()) return } r.Tools = data diff --git a/dto/dalle.go b/dto/openai_image.go similarity index 51% rename from dto/dalle.go rename to dto/openai_image.go index ce2f6361..7431935b 100644 --- a/dto/dalle.go +++ b/dto/openai_image.go @@ -1,11 +1,17 @@ package dto -import "encoding/json" +import ( + "encoding/json" + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) type ImageRequest struct { Model string `json:"model"` Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` + N uint `json:"n,omitempty"` Size string `json:"size,omitempty"` Quality string `json:"quality,omitempty"` ResponseFormat string `json:"response_format,omitempty"` @@ -18,6 +24,42 @@ type ImageRequest struct { Watermark *bool `json:"watermark,omitempty"` } +func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { + var sizeRatio = 1.0 + var qualityRatio = 1.0 + + if strings.HasPrefix(i.Model, "dall-e") { + // Size + if i.Size == "256x256" { + sizeRatio = 0.4 + } else if i.Size == "512x512" { + sizeRatio = 0.45 + } else if i.Size == "1024x1024" { + sizeRatio = 1 + } else if i.Size == "1024x1792" || i.Size == "1792x1024" { + sizeRatio = 2 + } + + if i.Model == "dall-e-3" && i.Quality == "hd" { + qualityRatio = 2.0 + if i.Size == "1024x1792" || i.Size == "1792x1024" { + qualityRatio = 1.5 + } + } + } + + // not support token count for dalle + return &types.TokenCountMeta{ + CombineText: i.Prompt, + MaxTokens: 1584, + ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N), + } +} + +func (i *ImageRequest) IsStream(c *gin.Context) bool { + return false +} + type ImageResponse struct { Data []ImageData `json:"data"` Created int64 `json:"created"` diff --git a/dto/openai_request.go b/dto/openai_request.go index 7a23ca5c..0c01c503 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -2,8 +2,12 @@ package dto import ( "encoding/json" + "fmt" "one-api/common" + "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) type ResponseFormat struct { @@ -67,6 +71,116 @@ type GeneralOpenAIRequest struct { Extra map[string]json.RawMessage `json:"-"` } +func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { + var tokenCountMeta types.TokenCountMeta + var texts = make([]string, 0) + var fileMeta = make([]*types.FileMeta, 0) + + if r.Prompt != nil { + switch v := r.Prompt.(type) { + case string: + texts = append(texts, v) + case []any: + for _, item := range v { + if str, ok := item.(string); ok { + texts = append(texts, str) + } + } + default: + texts = append(texts, fmt.Sprintf("%v", r.Prompt)) + } + } + + if r.Input != nil { + inputs := r.ParseInput() + texts = append(texts, inputs...) + } + + if r.MaxCompletionTokens > r.MaxTokens { + tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens) + } else { + tokenCountMeta.MaxTokens = int(r.MaxTokens) + } + + for _, message := range r.Messages { + tokenCountMeta.MessagesCount++ + texts = append(texts, message.Role) + if message.Content != nil { + if message.Name != nil { + tokenCountMeta.NameCount++ + texts = append(texts, *message.Name) + } + arrayContent := message.ParseContent() + for _, m := range arrayContent { + if m.Type == ContentTypeImageURL { + imageUrl := m.GetImageMedia() + if imageUrl != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeImage, + } + meta.Data = imageUrl.Url + meta.Detail = imageUrl.Detail + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeInputAudio { + inputAudio := m.GetInputAudio() + if inputAudio != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeAudio, + } + meta.Data = inputAudio.Data + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeFile { + file := m.GetFile() + if file != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeFile, + } + meta.Data = file.FileData + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeVideoUrl { + videoUrl := m.GetVideoUrl() + if videoUrl != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeVideo, + } + meta.Data = videoUrl.Url + fileMeta = append(fileMeta, meta) + } + } else { + texts = append(texts, m.Text) + } + } + } + } + + if r.Tools != nil { + openaiTools := r.Tools + for _, tool := range openaiTools { + tokenCountMeta.ToolsCount++ + texts = append(texts, tool.Function.Name) + if tool.Function.Description != "" { + texts = append(texts, tool.Function.Description) + } + if tool.Function.Parameters != nil { + texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters)) + } + } + //toolTokens := CountTokenInput(countStr, request.Model) + //tkm += 8 + //tkm += toolTokens + } + tokenCountMeta.CombineText = strings.Join(texts, "\n") + tokenCountMeta.Files = fileMeta + return &tokenCountMeta +} + +func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool { + return r.Stream +} + func (r *GeneralOpenAIRequest) ToMap() map[string]any { result := make(map[string]any) data, _ := common.Marshal(r) @@ -202,10 +316,25 @@ func (m *MediaContent) GetFile() *MessageFile { return nil } +func (m *MediaContent) GetVideoUrl() *MessageVideoUrl { + if m.VideoUrl != nil { + if _, ok := m.VideoUrl.(*MessageVideoUrl); ok { + return m.VideoUrl.(*MessageVideoUrl) + } + if itemMap, ok := m.VideoUrl.(map[string]any); ok { + out := &MessageVideoUrl{ + Url: common.Interface2String(itemMap["url"]), + } + return out + } + } + return nil +} + type MessageImageUrl struct { - Url string `json:"url"` - Detail string `json:"detail"` - MimeType string + Url string `json:"url"` + Detail string `json:"detail"` + //MimeType string } func (m *MessageImageUrl) IsRemoteImage() bool { @@ -233,6 +362,7 @@ const ( ContentTypeInputAudio = "input_audio" ContentTypeFile = "file" ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别 + //ContentTypeAudioUrl = "audio_url" ) func (m *Message) GetPrefix() bool { @@ -623,7 +753,7 @@ type WebSearchOptions struct { // https://platform.openai.com/docs/api-reference/responses/create type OpenAIResponsesRequest struct { Model string `json:"model"` - Input json.RawMessage `json:"input,omitempty"` + Input any `json:"input,omitempty"` Include json.RawMessage `json:"include,omitempty"` Instructions json.RawMessage `json:"instructions,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"` @@ -645,28 +775,145 @@ type OpenAIResponsesRequest struct { Prompt json.RawMessage `json:"prompt,omitempty"` } +func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { + var fileMeta = make([]*types.FileMeta, 0) + var texts = make([]string, 0) + + if r.Input != nil { + inputs := r.ParseInput() + for _, input := range inputs { + if input.Type == "input_image" { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeImage, + Data: input.ImageUrl, + Detail: input.Detail, + }) + } else if input.Type == "input_file" { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeFile, + Data: input.FileUrl, + }) + } else { + texts = append(texts, input.Text) + } + } + } + + if len(r.Instructions) > 0 { + texts = append(texts, string(r.Instructions)) + } + + if len(r.Metadata) > 0 { + texts = append(texts, string(r.Metadata)) + } + + if len(r.Text) > 0 { + texts = append(texts, string(r.Text)) + } + + if len(r.ToolChoice) > 0 { + texts = append(texts, string(r.ToolChoice)) + } + + if len(r.Prompt) > 0 { + texts = append(texts, string(r.Prompt)) + } + + if len(r.Tools) > 0 { + toolStr, _ := common.Marshal(r.Tools) + texts = append(texts, string(toolStr)) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + Files: fileMeta, + MaxTokens: int(r.MaxOutputTokens), + } +} + +func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool { + return r.Stream +} + type Reasoning struct { Effort string `json:"effort,omitempty"` Summary string `json:"summary,omitempty"` } -//type ResponsesToolsCall struct { -// Type string `json:"type"` -// // Web Search -// UserLocation json.RawMessage `json:"user_location,omitempty"` -// SearchContextSize string `json:"search_context_size,omitempty"` -// // File Search -// VectorStoreIds []string `json:"vector_store_ids,omitempty"` -// MaxNumResults uint `json:"max_num_results,omitempty"` -// Filters json.RawMessage `json:"filters,omitempty"` -// // Computer Use -// DisplayWidth uint `json:"display_width,omitempty"` -// DisplayHeight uint `json:"display_height,omitempty"` -// Environment string `json:"environment,omitempty"` -// // Function -// Name string `json:"name,omitempty"` -// Description string `json:"description,omitempty"` -// Parameters json.RawMessage `json:"parameters,omitempty"` -// Function json.RawMessage `json:"function,omitempty"` -// Container json.RawMessage `json:"container,omitempty"` -//} +type MediaInput struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + FileUrl string `json:"file_url,omitempty"` + ImageUrl string `json:"image_url,omitempty"` + Detail string `json:"detail,omitempty"` // 仅 input_image 有效 +} + +// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput. +// Reference implementation mirrors Message.ParseContent: +// - input can be a string, treated as an input_text item +// - input can be an array of objects with a `type` field +// supported types: input_text, input_image, input_file +func (r *OpenAIResponsesRequest) ParseInput() []MediaInput { + if r.Input == nil { + return nil + } + + var inputs []MediaInput + + // Try string first + if str, ok := r.Input.(string); ok { + inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) + return inputs + } + + // Try array of parts + if array, ok := r.Input.([]any); ok { + for _, itemAny := range array { + // Already parsed MediaInput + if media, ok := itemAny.(MediaInput); ok { + inputs = append(inputs, media) + continue + } + // Generic map + item, ok := itemAny.(map[string]any) + if !ok { + continue + } + typeVal, ok := item["type"].(string) + if !ok { + continue + } + switch typeVal { + case "input_text": + text, _ := item["text"].(string) + inputs = append(inputs, MediaInput{Type: "input_text", Text: text}) + case "input_image": + // image_url may be string or object with url field + var imageUrl string + switch v := item["image_url"].(type) { + case string: + imageUrl = v + case map[string]any: + if url, ok := v["url"].(string); ok { + imageUrl = url + } + } + inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl}) + case "input_file": + // file_url may be string or object with url field + var fileUrl string + switch v := item["file_url"].(type) { + case string: + fileUrl = v + case map[string]any: + if url, ok := v["url"].(string); ok { + fileUrl = url + } + } + inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl}) + } + } + } + + return inputs +} diff --git a/dto/request_common.go b/dto/request_common.go new file mode 100644 index 00000000..e5dde8b5 --- /dev/null +++ b/dto/request_common.go @@ -0,0 +1,11 @@ +package dto + +import ( + "github.com/gin-gonic/gin" + "one-api/types" +) + +type Request interface { + GetTokenCountMeta() *types.TokenCountMeta + IsStream(c *gin.Context) bool +} diff --git a/dto/rerank.go b/dto/rerank.go index 5ea68cba..ca4da9e1 100644 --- a/dto/rerank.go +++ b/dto/rerank.go @@ -1,5 +1,12 @@ package dto +import ( + "fmt" + "github.com/gin-gonic/gin" + "one-api/types" + "strings" +) + type RerankRequest struct { Documents []any `json:"documents"` Query string `json:"query"` @@ -10,6 +17,26 @@ type RerankRequest struct { OverLapTokens int `json:"overlap_tokens,omitempty"` } +func (r *RerankRequest) IsStream(c *gin.Context) bool { + return false +} + +func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta { + var texts = make([]string, 0) + + for _, document := range r.Documents { + texts = append(texts, fmt.Sprintf("%v", document)) + } + + if r.Query != "" { + texts = append(texts, r.Query) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + } +} + func (r *RerankRequest) GetReturnDocuments() bool { if r.ReturnDocuments == nil { return false diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 00000000..ca81d624 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,115 @@ +package logger + +import ( + "context" + "encoding/json" + "fmt" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "io" + "log" + "one-api/common" + "os" + "path/filepath" + "sync" + "time" +) + +const ( + loggerINFO = "INFO" + loggerWarn = "WARN" + loggerError = "ERR" + loggerDebug = "DEBUG" +) + +const maxLogCount = 1000000 + +var logCount int +var setupLogLock sync.Mutex +var setupLogWorking bool + +func SetupLogger() { + if *common.LogDir != "" { + ok := setupLogLock.TryLock() + if !ok { + log.Println("setup log is already working") + return + } + defer func() { + setupLogLock.Unlock() + setupLogWorking = false + }() + logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) + fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatal("failed to open log file") + } + gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) + gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) + } +} + +func LogInfo(ctx context.Context, msg string) { + logHelper(ctx, loggerINFO, msg) +} + +func LogWarn(ctx context.Context, msg string) { + logHelper(ctx, loggerWarn, msg) +} + +func LogError(ctx context.Context, msg string) { + logHelper(ctx, loggerError, msg) +} + +func LogDebug(ctx context.Context, msg string) { + if common.DebugEnabled { + logHelper(ctx, loggerDebug, msg) + } +} + +func logHelper(ctx context.Context, level string, msg string) { + writer := gin.DefaultErrorWriter + if level == loggerINFO { + writer = gin.DefaultWriter + } + id := ctx.Value(common.RequestIdKey) + if id == nil { + id = "SYSTEM" + } + now := time.Now() + _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) + logCount++ // we don't need accurate count, so no lock here + if logCount > maxLogCount && !setupLogWorking { + logCount = 0 + setupLogWorking = true + gopool.Go(func() { + SetupLogger() + }) + } +} + +func LogQuota(quota int) string { + if common.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit) + } else { + return fmt.Sprintf("%d 点额度", quota) + } +} + +func FormatQuota(quota int) string { + if common.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit) + } else { + return fmt.Sprintf("%d", quota) + } +} + +// LogJson 仅供测试使用 only for test +func LogJson(ctx context.Context, msg string, obj any) { + jsonStr, err := json.Marshal(obj) + if err != nil { + LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) + return + } + LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr))) +} diff --git a/main.go b/main.go index ca3da601..9a5bd652 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "one-api/common" "one-api/constant" "one-api/controller" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/router" @@ -35,22 +36,22 @@ func main() { err := InitResources() if err != nil { - common.FatalLog("failed to initialize resources: " + err.Error()) + logger.FatalLog("failed to initialize resources: " + err.Error()) return } - common.SysLog("New API " + common.Version + " started") + logger.SysLog("New API " + common.Version + " started") if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) } if common.DebugEnabled { - common.SysLog("running in debug mode") + logger.SysLog("running in debug mode") } defer func() { err := model.CloseDB() if err != nil { - common.FatalLog("failed to close database: " + err.Error()) + logger.FatalLog("failed to close database: " + err.Error()) } }() @@ -59,18 +60,18 @@ func main() { common.MemoryCacheEnabled = true } if common.MemoryCacheEnabled { - common.SysLog("memory cache enabled") - common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + logger.SysLog("memory cache enabled") + logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) // Add panic recovery and retry for InitChannelCache func() { defer func() { if r := recover(); r != nil { - common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) + logger.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) // Retry once _, _, fixErr := model.FixAbility() if fixErr != nil { - common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) + logger.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) } } }() @@ -89,14 +90,14 @@ func main() { if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) } go controller.AutomaticallyUpdateChannels(frequency) } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) } go controller.AutomaticallyTestChannels(frequency) } @@ -110,7 +111,7 @@ func main() { } if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { common.BatchUpdateEnabled = true - common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") model.InitBatchUpdater() } @@ -119,13 +120,13 @@ func main() { log.Println(http.ListenAndServe("0.0.0.0:8005", nil)) }) go common.Monitor() - common.SysLog("pprof enabled") + logger.SysLog("pprof enabled") } // Initialize HTTP server server := gin.New() server.Use(gin.CustomRecovery(func(c *gin.Context, err any) { - common.SysError(fmt.Sprintf("panic detected: %v", err)) + logger.SysError(fmt.Sprintf("panic detected: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), @@ -155,7 +156,7 @@ func main() { } err = server.Run(":" + port) if err != nil { - common.FatalLog("failed to start HTTP server: " + err.Error()) + logger.FatalLog("failed to start HTTP server: " + err.Error()) } } @@ -164,14 +165,14 @@ func InitResources() error { // This is a placeholder function for future resource initialization err := godotenv.Load(".env") if err != nil { - common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") - common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") + logger.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") + logger.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") } // 加载环境变量 common.InitEnv() - common.SetupLogger() + logger.SetupLogger() // Initialize model settings ratio_setting.InitRatioSettings() @@ -183,7 +184,7 @@ func InitResources() error { // Initialize SQL Database err = model.InitDB() if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) + logger.FatalLog("failed to initialize database: " + err.Error()) return err } diff --git a/middleware/recover.go b/middleware/recover.go index 51fc7190..6c9c7ef6 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/logger" "runtime/debug" ) @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - common.SysError(fmt.Sprintf("panic detected: %v", err)) - common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.SysError(fmt.Sprintf("panic detected: %v", err)) + logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 26688810..a136a900 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/logger" ) type turnstileCheckResponse struct { @@ -37,7 +38,7 @@ func TurnstileCheck() gin.HandlerFunc { "remoteip": {c.ClientIP()}, }) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/middleware/utils.go b/middleware/utils.go index 082f5657..e23bbff7 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "one-api/common" + "one-api/logger" ) func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { @@ -15,7 +16,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { }, }) c.Abort() - common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) + logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) } func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { @@ -25,5 +26,5 @@ func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, descri "code": code, }) c.Abort() - common.LogError(c.Request.Context(), description) + logger.LogError(c.Request.Context(), description) } diff --git a/model/ability.go b/model/ability.go index ce2f299c..ac5530d8 100644 --- a/model/ability.go +++ b/model/ability.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "strings" "sync" @@ -294,13 +295,13 @@ func FixAbility() (int, int, error) { if common.UsingSQLite { err := DB.Exec("DELETE FROM abilities").Error if err != nil { - common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) return 0, 0, err } } else { err := DB.Exec("TRUNCATE TABLE abilities").Error if err != nil { - common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) + logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) return 0, 0, err } } @@ -320,7 +321,7 @@ func FixAbility() (int, int, error) { // Delete all abilities of this channel err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error if err != nil { - common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) failCount += len(chunk) continue } @@ -328,7 +329,7 @@ func FixAbility() (int, int, error) { for _, channel := range chunk { err = channel.AddAbilities(nil) if err != nil { - common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) + logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) failCount++ } else { successCount++ diff --git a/model/channel.go b/model/channel.go index 6239f05c..c0d253fc 100644 --- a/model/channel.go +++ b/model/channel.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/types" "strings" "sync" @@ -209,7 +210,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { if channel.OtherInfo != "" { err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { - common.SysError("failed to unmarshal other info: " + err.Error()) + logger.SysError("failed to unmarshal other info: " + err.Error()) } } return otherInfo @@ -218,7 +219,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { otherInfoBytes, err := json.Marshal(otherInfo) if err != nil { - common.SysError("failed to marshal other info: " + err.Error()) + logger.SysError("failed to marshal other info: " + err.Error()) return } channel.OtherInfo = string(otherInfoBytes) @@ -488,7 +489,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { ResponseTime: int(responseTime), }).Error if err != nil { - common.SysError("failed to update response time: " + err.Error()) + logger.SysError("failed to update response time: " + err.Error()) } } @@ -498,7 +499,7 @@ func (channel *Channel) UpdateBalance(balance float64) { Balance: balance, }).Error if err != nil { - common.SysError("failed to update balance: " + err.Error()) + logger.SysError("failed to update balance: " + err.Error()) } } @@ -614,7 +615,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if shouldUpdateAbilities { err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) + logger.SysError("failed to update ability status: " + err.Error()) } } }() @@ -642,7 +643,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri } err = channel.Save() if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + logger.SysError("failed to update channel status: " + err.Error()) return false } } @@ -704,7 +705,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models * for _, channel := range channels { err = channel.UpdateAbilities(nil) if err != nil { - common.SysError("failed to update abilities: " + err.Error()) + logger.SysError("failed to update abilities: " + err.Error()) } } } @@ -728,7 +729,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - common.SysError("failed to update channel used quota: " + err.Error()) + logger.SysError("failed to update channel used quota: " + err.Error()) } } @@ -821,7 +822,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { if channel.Setting != nil && *channel.Setting != "" { err := common.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + logger.SysError("failed to unmarshal setting: " + err.Error()) channel.Setting = nil // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -832,7 +833,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { func (channel *Channel) SetSetting(setting dto.ChannelSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + logger.SysError("failed to marshal setting: " + err.Error()) return } channel.Setting = common.GetPointer[string](string(settingBytes)) @@ -843,7 +844,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { if channel.OtherSettings != "" { err := common.UnmarshalJsonStr(channel.OtherSettings, &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + logger.SysError("failed to unmarshal setting: " + err.Error()) channel.OtherSettings = "{}" // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -854,7 +855,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + logger.SysError("failed to marshal setting: " + err.Error()) return } channel.OtherSettings = string(settingBytes) @@ -865,7 +866,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} { if channel.ParamOverride != nil && *channel.ParamOverride != "" { err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) if err != nil { - common.SysError("failed to unmarshal param override: " + err.Error()) + logger.SysError("failed to unmarshal param override: " + err.Error()) } } return paramOverride diff --git a/model/channel_cache.go b/model/channel_cache.go index 86866e40..22216027 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -6,6 +6,7 @@ import ( "math/rand" "one-api/common" "one-api/constant" + "one-api/logger" "one-api/setting" "one-api/setting/ratio_setting" "sort" @@ -84,13 +85,13 @@ func InitChannelCache() { } channelsIDM = newChannelId2channel channelSyncLock.Unlock() - common.SysLog("channels synced from database") + logger.SysLog("channels synced from database") } func SyncChannelCache(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing channels from database") + logger.SysLog("syncing channels from database") InitChannelCache() } } diff --git a/model/log.go b/model/log.go index 2070cd6f..d9495968 100644 --- a/model/log.go +++ b/model/log.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "one-api/common" + "one-api/logger" "os" "strings" "time" @@ -87,13 +88,13 @@ func RecordLog(userId int, logType int, content string) { } err := LOG_DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + logger.SysError("failed to record log: " + err.Error()) } } func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int, isStream bool, group string, other map[string]interface{}) { - common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) + logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) username := c.GetString("username") otherStr := common.MapToJsonStr(other) // 判断是否需要记录 IP @@ -129,7 +130,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, } err := LOG_DB.Create(log).Error if err != nil { - common.LogError(c, "failed to record log: "+err.Error()) + logger.LogError(c, "failed to record log: "+err.Error()) } } @@ -142,7 +143,6 @@ type RecordConsumeLogParams struct { Quota int `json:"quota"` Content string `json:"content"` TokenId int `json:"token_id"` - UserQuota int `json:"user_quota"` UseTimeSeconds int `json:"use_time_seconds"` IsStream bool `json:"is_stream"` Group string `json:"group"` @@ -150,7 +150,7 @@ type RecordConsumeLogParams struct { } func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) { - common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) + logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) if !common.LogConsumeEnabled { return } @@ -189,7 +189,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) } err := LOG_DB.Create(log).Error if err != nil { - common.LogError(c, "failed to record log: "+err.Error()) + logger.LogError(c, "failed to record log: "+err.Error()) } if common.DataExportEnabled { gopool.Go(func() { diff --git a/model/main.go b/model/main.go index dbf27152..1e582e1a 100644 --- a/model/main.go +++ b/model/main.go @@ -5,6 +5,7 @@ import ( "log" "one-api/common" "one-api/constant" + "one-api/logger" "os" "strings" "sync" @@ -84,7 +85,7 @@ func createRootAccountIfNeed() error { var user User //if user.Status != common.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - common.SysLog("no user exists, create a root user for you: username is root, password is 123456") + logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return err @@ -108,7 +109,7 @@ func CheckSetup() { if setup == nil { // No setup record exists, check if we have a root user if RootUserExists() { - common.SysLog("system is not initialized, but root user exists") + logger.SysLog("system is not initialized, but root user exists") // Create setup record newSetup := Setup{ Version: common.Version, @@ -116,16 +117,16 @@ func CheckSetup() { } err := DB.Create(&newSetup).Error if err != nil { - common.SysLog("failed to create setup record: " + err.Error()) + logger.SysLog("failed to create setup record: " + err.Error()) } constant.Setup = true } else { - common.SysLog("system is not initialized and no root user exists") + logger.SysLog("system is not initialized and no root user exists") constant.Setup = false } } else { // Setup record exists, system is initialized - common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) + logger.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) constant.Setup = true } } @@ -138,7 +139,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { if dsn != "" { if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { // Use PostgreSQL - common.SysLog("using PostgreSQL as database") + logger.SysLog("using PostgreSQL as database") if !isLog { common.UsingPostgreSQL = true } else { @@ -152,7 +153,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } if strings.HasPrefix(dsn, "local") { - common.SysLog("SQL_DSN not set, using SQLite as database") + logger.SysLog("SQL_DSN not set, using SQLite as database") if !isLog { common.UsingSQLite = true } else { @@ -163,7 +164,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } // Use MySQL - common.SysLog("using MySQL as database") + logger.SysLog("using MySQL as database") // check parseTime if !strings.Contains(dsn, "parseTime") { if strings.Contains(dsn, "?") { @@ -182,7 +183,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } // Use SQLite - common.SysLog("SQL_DSN not set, using SQLite as database") + logger.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ PrepareStmt: true, // precompile SQL @@ -216,11 +217,11 @@ func InitDB() (err error) { if common.UsingMySQL { //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded } - common.SysLog("database migration started") + logger.SysLog("database migration started") err = migrateDB() return err } else { - common.FatalLog(err) + logger.FatalLog(err) } return err } @@ -253,11 +254,11 @@ func InitLogDB() (err error) { if !common.IsMasterNode { return nil } - common.SysLog("database migration started") + logger.SysLog("database migration started") err = migrateLOGDB() return err } else { - common.FatalLog(err) + logger.FatalLog(err) } return err } @@ -354,7 +355,7 @@ func migrateDBFast() error { return err } } - common.SysLog("database migrated") + logger.SysLog("database migrated") return nil } @@ -503,6 +504,6 @@ func PingDB() error { } lastPingTime = time.Now() - common.SysLog("Database pinged successfully") + logger.SysLog("Database pinged successfully") return nil } diff --git a/model/option.go b/model/option.go index 5c84d166..8fcd13a8 100644 --- a/model/option.go +++ b/model/option.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/logger" "one-api/setting" "one-api/setting/config" "one-api/setting/operation_setting" @@ -150,7 +151,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - common.SysError("failed to update option map: " + err.Error()) + logger.SysError("failed to update option map: " + err.Error()) } } } @@ -158,7 +159,7 @@ func loadOptionsFromDatabase() { func SyncOptions(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing options from database") + logger.SysLog("syncing options from database") loadOptionsFromDatabase() } } diff --git a/model/pricing.go b/model/pricing.go index 0936d298..31aa5cdf 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -3,6 +3,7 @@ package model import ( "encoding/json" "fmt" + "one-api/logger" "strings" "one-api/common" @@ -92,7 +93,7 @@ func updatePricing() { //modelRatios := common.GetModelRatios() enableAbilities, err := GetAllEnableAbilityWithChannels() if err != nil { - common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) + logger.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) return } // 预加载模型元数据与供应商一次,避免循环查询 diff --git a/model/redemption.go b/model/redemption.go index bf237668..1ab84f45 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "strconv" "gorm.io/gorm" @@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return 0, errors.New("兑换失败," + err.Error()) } - RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id)) + RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id)) return redemption.Quota, nil } diff --git a/model/token.go b/model/token.go index e85a445e..63c17e2d 100644 --- a/model/token.go +++ b/model/token.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "strings" "github.com/bytedance/gopkg/util/gopool" @@ -91,7 +92,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExpired err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } return token, errors.New("该令牌已过期") @@ -102,7 +103,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExhausted err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } keyPrefix := key[:3] @@ -134,7 +135,7 @@ func GetTokenById(id int) (*Token, error) { if shouldUpdateRedis(true, err) { gopool.Go(func() { if err := cacheSetToken(token); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + logger.SysError("failed to update user status cache: " + err.Error()) } }) } @@ -147,7 +148,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) { if shouldUpdateRedis(fromDB, err) && token != nil { gopool.Go(func() { if err := cacheSetToken(*token); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + logger.SysError("failed to update user status cache: " + err.Error()) } }) } @@ -178,7 +179,7 @@ func (token *Token) Update() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - common.SysError("failed to update token cache: " + err.Error()) + logger.SysError("failed to update token cache: " + err.Error()) } }) } @@ -194,7 +195,7 @@ func (token *Token) SelectUpdate() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - common.SysError("failed to update token cache: " + err.Error()) + logger.SysError("failed to update token cache: " + err.Error()) } }) } @@ -209,7 +210,7 @@ func (token *Token) Delete() (err error) { gopool.Go(func() { err := cacheDeleteToken(token.Key) if err != nil { - common.SysError("failed to delete token cache: " + err.Error()) + logger.SysError("failed to delete token cache: " + err.Error()) } }) } @@ -269,7 +270,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheIncrTokenQuota(key, int64(quota)) if err != nil { - common.SysError("failed to increase token quota: " + err.Error()) + logger.SysError("failed to increase token quota: " + err.Error()) } }) } @@ -299,7 +300,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheDecrTokenQuota(key, int64(quota)) if err != nil { - common.SysError("failed to decrease token quota: " + err.Error()) + logger.SysError("failed to decrease token quota: " + err.Error()) } }) } diff --git a/model/topup.go b/model/topup.go index c34c0ce6..802c866f 100644 --- a/model/topup.go +++ b/model/topup.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "gorm.io/gorm" ) @@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) { return errors.New("充值失败," + err.Error()) } - RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount)) + RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount)) return nil } diff --git a/model/twofa.go b/model/twofa.go index d09ff9fe..b2ea54e0 100644 --- a/model/twofa.go +++ b/model/twofa.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "time" "gorm.io/gorm" @@ -243,7 +244,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { if !common.ValidateTOTPCode(t.Secret, code) { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - common.SysError("更新2FA失败次数失败: " + err.Error()) + logger.SysError("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -255,7 +256,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - common.SysError("更新2FA使用记录失败: " + err.Error()) + logger.SysError("更新2FA使用记录失败: " + err.Error()) } return true, nil @@ -277,7 +278,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { if !valid { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - common.SysError("更新2FA失败次数失败: " + err.Error()) + logger.SysError("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -289,7 +290,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - common.SysError("更新2FA使用记录失败: " + err.Error()) + logger.SysError("更新2FA使用记录失败: " + err.Error()) } return true, nil diff --git a/model/usedata.go b/model/usedata.go index 1255b0be..f0027a8d 100644 --- a/model/usedata.go +++ b/model/usedata.go @@ -4,6 +4,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/logger" "sync" "time" ) @@ -24,12 +25,12 @@ func UpdateQuotaData() { // recover defer func() { if r := recover(); r != nil { - common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) + logger.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) } }() for { if common.DataExportEnabled { - common.SysLog("正在更新数据看板数据...") + logger.SysLog("正在更新数据看板数据...") SaveQuotaDataCache() } time.Sleep(time.Duration(common.DataExportInterval) * time.Minute) @@ -91,7 +92,7 @@ func SaveQuotaDataCache() { } } CacheQuotaData = make(map[string]*QuotaData) - common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size)) + logger.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size)) } func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) { @@ -102,7 +103,7 @@ func increaseQuotaData(userId int, username string, modelName string, count int, "token_used": gorm.Expr("token_used + ?", tokenUsed), }).Error if err != nil { - common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err)) + logger.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err)) } } diff --git a/model/user.go b/model/user.go index 6021f495..244380ad 100644 --- a/model/user.go +++ b/model/user.go @@ -6,6 +6,7 @@ import ( "fmt" "one-api/common" "one-api/dto" + "one-api/logger" "strconv" "strings" @@ -75,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting { if user.Setting != "" { err := json.Unmarshal([]byte(user.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + logger.SysError("failed to unmarshal setting: " + err.Error()) } } return setting @@ -84,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting { func (user *User) SetSetting(setting dto.UserSetting) { settingBytes, err := json.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + logger.SysError("failed to marshal setting: " + err.Error()) return } user.Setting = string(settingBytes) @@ -274,7 +275,7 @@ func inviteUser(inviterId int) (err error) { func (user *User) TransferAffQuotaToQuota(quota int) error { // 检查quota是否小于最小额度 if float64(quota) < common.QuotaPerUnit { - return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit))) + return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit))) } // 开始数据库事务 @@ -324,16 +325,16 @@ func (user *User) Insert(inviterId int) error { return result.Error } if common.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) } if inviterId != 0 { if common.QuotaForInvitee > 0 { _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee))) } if common.QuotaForInviter > 0 { //_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter))) _ = inviteUser(inviterId) } } @@ -517,7 +518,7 @@ func IsAdmin(userId int) bool { var user User err := DB.Where("id = ?", userId).Select("role").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) + logger.SysError("no such user " + err.Error()) return false } return user.Role >= common.RoleAdminUser @@ -572,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserQuotaCache(id, quota); err != nil { - common.SysError("failed to update user quota cache: " + err.Error()) + logger.SysError("failed to update user quota cache: " + err.Error()) } }) } @@ -610,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserGroupCache(id, group); err != nil { - common.SysError("failed to update user group cache: " + err.Error()) + logger.SysError("failed to update user group cache: " + err.Error()) } }) } @@ -639,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserSettingCache(id, setting); err != nil { - common.SysError("failed to update user setting cache: " + err.Error()) + logger.SysError("failed to update user setting cache: " + err.Error()) } }) } @@ -669,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) { gopool.Go(func() { err := cacheIncrUserQuota(id, int64(quota)) if err != nil { - common.SysError("failed to increase user quota: " + err.Error()) + logger.SysError("failed to increase user quota: " + err.Error()) } }) if !db && common.BatchUpdateEnabled { @@ -694,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { gopool.Go(func() { err := cacheDecrUserQuota(id, int64(quota)) if err != nil { - common.SysError("failed to decrease user quota: " + err.Error()) + logger.SysError("failed to decrease user quota: " + err.Error()) } }) if common.BatchUpdateEnabled { @@ -750,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota and request count: " + err.Error()) + logger.SysError("failed to update user used quota and request count: " + err.Error()) return } @@ -767,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota: " + err.Error()) + logger.SysError("failed to update user used quota: " + err.Error()) } } func updateUserRequestCount(id int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error if err != nil { - common.SysError("failed to update user request count: " + err.Error()) + logger.SysError("failed to update user request count: " + err.Error()) } } @@ -785,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserNameCache(id, username); err != nil { - common.SysError("failed to update user name cache: " + err.Error()) + logger.SysError("failed to update user name cache: " + err.Error()) } }) } diff --git a/model/user_cache.go b/model/user_cache.go index a631457c..dec7597b 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -5,6 +5,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "time" "github.com/gin-gonic/gin" @@ -37,7 +38,7 @@ func (user *UserBase) GetSetting() dto.UserSetting { if user.Setting != "" { err := common.Unmarshal([]byte(user.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + logger.SysError("failed to unmarshal setting: " + err.Error()) } } return setting @@ -78,7 +79,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) { if shouldUpdateRedis(fromDB, err) && user != nil { gopool.Go(func() { if err := updateUserCache(*user); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + logger.SysError("failed to update user status cache: " + err.Error()) } }) } diff --git a/model/utils.go b/model/utils.go index 1f8a0963..abd96b79 100644 --- a/model/utils.go +++ b/model/utils.go @@ -3,6 +3,7 @@ package model import ( "errors" "one-api/common" + "one-api/logger" "sync" "time" @@ -65,7 +66,7 @@ func batchUpdate() { return } - common.SysLog("batch update started") + logger.SysLog("batch update started") for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] @@ -77,12 +78,12 @@ func batchUpdate() { case BatchUpdateTypeUserQuota: err := increaseUserQuota(key, value) if err != nil { - common.SysError("failed to batch update user quota: " + err.Error()) + logger.SysError("failed to batch update user quota: " + err.Error()) } case BatchUpdateTypeTokenQuota: err := increaseTokenQuota(key, value) if err != nil { - common.SysError("failed to batch update token quota: " + err.Error()) + logger.SysError("failed to batch update token quota: " + err.Error()) } case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) @@ -93,7 +94,7 @@ func batchUpdate() { } } } - common.SysLog("batch update finished") + logger.SysLog("batch update finished") } func RecordExist(err error) (bool, error) { diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 88777838..1bc2d90b 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -4,107 +4,40 @@ import ( "errors" "fmt" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/types" - "strings" "github.com/gin-gonic/gin" ) -func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { - audioRequest := &dto.AudioRequest{} - err := common.UnmarshalBodyReusable(c, audioRequest) - if err != nil { - return nil, err - } - switch info.RelayMode { - case relayconstant.RelayModeAudioSpeech: - if audioRequest.Model == "" { - return nil, errors.New("model is required") - } - if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveInput(audioRequest.Input) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) - return nil, err - } - } - default: - err = c.Request.ParseForm() - if err != nil { - return nil, err - } - formData := c.Request.PostForm - if audioRequest.Model == "" { - audioRequest.Model = formData.Get("model") - } +func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - if audioRequest.Model == "" { - return nil, errors.New("model is required") - } - audioRequest.ResponseFormat = formData.Get("response_format") - if audioRequest.ResponseFormat == "" { - audioRequest.ResponseFormat = "json" - } - } - return audioRequest, nil -} - -func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c) - audioRequest, err := getAndValidAudioRequest(c, relayInfo) - - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + audioRequest, ok := info.Request.(*dto.AudioRequest) + if !ok { + return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - promptTokens := 0 - preConsumedTokens := common.PreConsumedQuota - if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { - promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model) - preConsumedTokens = promptTokens - relayInfo.PromptTokens = promptTokens - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr - } - defer func() { - if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - err = helper.ModelMappedHelper(c, relayInfo, audioRequest) + err := helper.ModelMappedHelper(c, info, audioRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) - ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) + ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - resp, err := adaptor.DoRequest(c, relayInfo, ioReader) + resp, err := adaptor.DoRequest(c, info, ioReader) if err != nil { return types.NewError(err, types.ErrorCodeDoRequestFailed) } @@ -121,14 +54,14 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 754f29c8..841896cf 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -6,8 +6,8 @@ import ( "fmt" "io" "net/http" - "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/service" "one-api/types" @@ -43,7 +43,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error client := &http.Client{} resp, err := client.Do(req) if err != nil { - common.SysError("updateTask client.Do err: " + err.Error()) + logger.SysError("updateTask client.Do err: " + err.Error()) return &aliResponse, err, nil } defer resp.Body.Close() @@ -53,7 +53,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error var response AliResponse err = json.Unmarshal(responseBody, &response) if err != nil { - common.SysError("updateTask NewDecoder err: " + err.Error()) + logger.SysError("updateTask NewDecoder err: " + err.Error()) return &aliResponse, err, nil } @@ -109,7 +109,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc if responseFormat == "b64_json" { _, b64, err := service.GetImageFromUrl(data.Url) if err != nil { - common.LogError(c, "get_image_data_failed: "+err.Error()) + logger.LogError(c, "get_image_data_failed: "+err.Error()) continue } b64Json = b64 @@ -134,14 +134,14 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliTaskResponse) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliTaskResponse.Message != "" { - common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) + logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil } diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go index 4f448e01..e7d6b514 100644 --- a/relay/channel/ali/rerank.go +++ b/relay/channel/ali/rerank.go @@ -4,9 +4,9 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -36,7 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var aliResponse AliRerankResponse err = json.Unmarshal(responseBody, &aliResponse) diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index fcf63854..17fcef2a 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -7,7 +7,9 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/helper" + "one-api/service" "strings" "one-api/types" @@ -46,7 +48,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) model := c.GetString("model") if model == "" { @@ -148,7 +150,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, var aliResponse AliResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if aliResponse.Usage.OutputTokens != 0 { @@ -161,7 +163,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -171,7 +173,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, &usage } @@ -181,7 +183,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 3ccd2d78..fd745cf7 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -7,6 +7,7 @@ import ( "io" "net/http" common2 "one-api/common" + "one-api/logger" "one-api/relay/common" "one-api/relay/constant" "one-api/relay/helper" @@ -181,7 +182,7 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error { err := helper.PingData(c) if err != nil { - common2.LogError(c, "SSE ping error: "+err.Error()) + logger.LogError(c, "SSE ping error: "+err.Error()) done <- err return } diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index a7cd5996..696c2496 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -118,7 +119,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var baiduResponse BaiduChatStreamResponse err := common.Unmarshal([]byte(data), &baiduResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if baiduResponse.Usage.TotalTokens != 0 { @@ -129,11 +130,11 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseBaidu2OpenAI(&baiduResponse) err = helper.ObjectData(c, response) if err != nil { - common.SysError("error sending stream response: " + err.Error()) + logger.SysError("error sending stream response: " + err.Error()) } return true }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, usage } @@ -143,7 +144,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil @@ -168,7 +169,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index e4d3975e..5d839908 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -375,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla for _, toolCall := range message.ParseToolCalls() { inputObj := make(map[string]any) if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { - common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) + logger.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) continue } claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ @@ -609,7 +610,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud var claudeResponse dto.ClaudeResponse err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return types.NewError(err, types.ErrorCodeBadResponseBody) } if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" { @@ -637,7 +638,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud err = helper.ObjectData(c, response) if err != nil { - common.LogError(c, "send_stream_response_failed: "+err.Error()) + logger.LogError(c, "send_stream_response_failed: "+err.Error()) } } return nil @@ -653,7 +654,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau } if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done { if common.DebugEnabled { - common.SysError("claude response usage is not complete, maybe upstream error") + logger.SysError("claude response usage is not complete, maybe upstream error") } claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } @@ -667,7 +668,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { - common.SysError("send final response failed: " + err.Error()) + logger.SysError("send final response failed: " + err.Error()) } } helper.Done(c) @@ -736,12 +737,12 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests) } - common.IOCopyBytesGracefully(c, nil, responseData) + service.IOCopyBytesGracefully(c, nil, responseData) return nil } func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) claudeInfo := &ClaudeResponseInfo{ ResponseId: helper.GetResponseID(c), diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index 5e8fe7f9..00f6b6c5 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -5,8 +5,8 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -51,7 +51,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res var response dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &response) if err != nil { - common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) + logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) continue } for _, choice := range response.Choices { @@ -66,24 +66,24 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res info.FirstResponseTime = time.Now() } if err != nil { - common.LogError(c, "error_rendering_stream_response: "+err.Error()) + logger.LogError(c, "error_rendering_stream_response: "+err.Error()) } } if err := scanner.Err(); err != nil { - common.LogError(c, "error_scanning_stream_response: "+err.Error()) + logger.LogError(c, "error_scanning_stream_response: "+err.Error()) } usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) err := helper.ObjectData(c, response) if err != nil { - common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) + logger.LogError(c, "error_rendering_final_usage_response: "+err.Error()) } } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, usage } @@ -93,7 +93,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var response dto.TextResponse err = json.Unmarshal(responseBody, &response) if err != nil { @@ -123,7 +123,7 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &cfResp) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index fcfb12b7..ccef9b23 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -118,7 +119,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var cohereResp CohereResponse err := json.Unmarshal([]byte(data), &cohereResp) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } var openaiResp dto.ChatCompletionsStreamResponse @@ -153,7 +154,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } jsonStr, err := json.Marshal(openaiResp) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) @@ -175,7 +176,7 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var cohereResp CohereResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { @@ -216,7 +217,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon. if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var cohereResp CohereRerankResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 32cc6937..18ed46af 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -9,6 +9,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -49,7 +50,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) // convert coze response to openai response var response dto.TextResponse var cozeResponse CozeChatDetailResponse @@ -154,7 +155,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var chatData CozeChatResponseData err := json.Unmarshal([]byte(data), &chatData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + logger.SysError("error_unmarshalling_stream_response: " + err.Error()) return } @@ -171,14 +172,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var messageData CozeChatV3MessageDetail err := json.Unmarshal([]byte(data), &messageData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + logger.SysError("error_unmarshalling_stream_response: " + err.Error()) return } var content string err = json.Unmarshal(messageData.Content, &content) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + logger.SysError("error_unmarshalling_stream_response: " + err.Error()) return } @@ -203,11 +204,11 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var errorData CozeError err := json.Unmarshal([]byte(data), &errorData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + logger.SysError("error_unmarshalling_stream_response: " + err.Error()) return } - common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + logger.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) } } diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 47337127..f03d61a4 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -36,14 +37,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Decode base64 string decodedData, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { - common.SysError("failed to decode base64: " + err.Error()) + logger.SysError("failed to decode base64: " + err.Error()) return nil } // Create temporary file tempFile, err := os.CreateTemp("", "dify-upload-*") if err != nil { - common.SysError("failed to create temp file: " + err.Error()) + logger.SysError("failed to create temp file: " + err.Error()) return nil } defer tempFile.Close() @@ -51,7 +52,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Write decoded data to temp file if _, err := tempFile.Write(decodedData); err != nil { - common.SysError("failed to write to temp file: " + err.Error()) + logger.SysError("failed to write to temp file: " + err.Error()) return nil } @@ -61,7 +62,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Add user field if err := writer.WriteField("user", user); err != nil { - common.SysError("failed to add user field: " + err.Error()) + logger.SysError("failed to add user field: " + err.Error()) return nil } @@ -74,13 +75,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create form file part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/"))) if err != nil { - common.SysError("failed to create form file: " + err.Error()) + logger.SysError("failed to create form file: " + err.Error()) return nil } // Copy file content to form if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil { - common.SysError("failed to copy file content: " + err.Error()) + logger.SysError("failed to copy file content: " + err.Error()) return nil } writer.Close() @@ -88,7 +89,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create HTTP request req, err := http.NewRequest("POST", uploadUrl, body) if err != nil { - common.SysError("failed to create request: " + err.Error()) + logger.SysError("failed to create request: " + err.Error()) return nil } @@ -99,7 +100,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me client := service.GetHttpClient() resp, err := client.Do(req) if err != nil { - common.SysError("failed to send request: " + err.Error()) + logger.SysError("failed to send request: " + err.Error()) return nil } defer resp.Body.Close() @@ -109,7 +110,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me Id string `json:"id"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - common.SysError("failed to decode response: " + err.Error()) + logger.SysError("failed to decode response: " + err.Error()) return nil } @@ -219,7 +220,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R var difyResponse DifyChunkChatCompletionResponse err := json.Unmarshal([]byte(data), &difyResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } var openaiResponse dto.ChatCompletionsStreamResponse @@ -239,7 +240,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } err = helper.ObjectData(c, openaiResponse) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) } return true }) @@ -258,7 +259,7 @@ func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &difyResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 4141caf7..05d974f6 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf }, }, Parameters: dto.GeminiImageParameters{ - SampleCount: request.N, + SampleCount: int(request.N), AspectRatio: aspectRatio, PersonGeneration: "allow_adult", // default allow adult }, diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 7f2f51fb..974a22f5 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -5,6 +5,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -17,7 +18,7 @@ import ( ) func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // 读取响应体 responseBody, err := io.ReadAll(resp.Body) @@ -53,13 +54,13 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re } } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &usage, nil } func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -89,7 +90,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel } } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return usage, nil } @@ -106,7 +107,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn var geminiResponse dto.GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { - common.LogError(c, "error unmarshalling stream response: "+err.Error()) + logger.LogError(c, "error unmarshalling stream response: "+err.Error()) return false } @@ -140,7 +141,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn // 直接发送 GeminiChatResponse 响应 err = helper.StringData(c, data) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } info.SendResponseCount++ return true diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 58efa1a5..82a2d8de 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -901,7 +902,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * var geminiResponse dto.GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { - common.LogError(c, "error unmarshalling stream response: "+err.Error()) + logger.LogError(c, "error unmarshalling stream response: "+err.Error()) return false } @@ -945,7 +946,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * finishReason = constant.FinishReasonToolCalls err = handleStream(c, info, emptyResponse) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } response.ClearToolCalls() @@ -957,7 +958,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * err = handleStream(c, info, response) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } if isStop { _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)) @@ -993,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) err := handleFinalStream(c, info, response) if err != nil { - common.SysError("send final response failed: " + err.Error()) + logger.SysError("send final response failed: " + err.Error()) } //if info.RelayFormat == relaycommon.RelayFormatOpenAI { // helper.Done(c) @@ -1007,7 +1008,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println(string(responseBody)) } @@ -1057,13 +1058,13 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R break } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &usage, nil } func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { @@ -1107,7 +1108,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go index 28af1866..11a0117b 100644 --- a/relay/channel/jimeng/image.go +++ b/relay/channel/jimeng/image.go @@ -5,9 +5,9 @@ import ( "fmt" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -54,7 +54,7 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &jimengResponse) if err != nil { diff --git a/relay/channel/jimeng/sign.go b/relay/channel/jimeng/sign.go index c9db6630..d8b598dc 100644 --- a/relay/channel/jimeng/sign.go +++ b/relay/channel/jimeng/sign.go @@ -12,7 +12,7 @@ import ( "io" "net/http" "net/url" - "one-api/common" + "one-api/logger" "sort" "strings" "time" @@ -44,7 +44,7 @@ func SetPayloadHash(c *gin.Context, req any) error { if err != nil { return err } - common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) + logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) payloadHash := sha256.Sum256(body) hexPayloadHash := hex.EncodeToString(payloadHash[:]) c.Set(HexPayloadHashKey, hexPayloadHash) diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go index 78f96d6d..d91aceb3 100644 --- a/relay/channel/mokaai/relay-mokaai.go +++ b/relay/channel/mokaai/relay-mokaai.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -56,7 +57,7 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) @@ -77,6 +78,6 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &fullTextResponse.Usage, nil } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index d4686ce3..066581fa 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -94,7 +94,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) @@ -123,7 +123,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.IOCopyBytesGracefully(c, resp, doResponseBody) + service.IOCopyBytesGracefully(c, resp, doResponseBody) return usage, nil } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 696c5cb0..80973aa1 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" @@ -50,7 +51,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { - common.LogError(c, "failed to unmarshal stream response: "+err.Error()) + logger.LogError(c, "failed to unmarshal stream response: "+err.Error()) return err } @@ -63,7 +64,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { - common.LogError(c, "failed to marshal gemini response: "+err.Error()) + logger.LogError(c, "failed to marshal gemini response: "+err.Error()) return err } @@ -110,14 +111,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex var streamResponses []dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { return err } if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { - common.SysError("error processing stream response: " + err.Error()) + logger.SysError("error processing stream response: " + err.Error()) } } return nil @@ -146,7 +147,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui var streamResponses []dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { @@ -213,7 +214,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream info.ClaudeConvertInfo.Done = true var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return } @@ -227,7 +228,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream case relaycommon.RelayFormatGemini: var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return } @@ -245,7 +246,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { - common.SysError("error marshalling gemini response: " + err.Error()) + logger.SysError("error marshalling gemini response: " + err.Error()) return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b8e72273..447e0f31 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -108,11 +109,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") + logger.LogError(c, "invalid response or response body") return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) } - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) model := info.UpstreamModelName var responseId string @@ -129,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re if lastStreamData != "" { err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) if err != nil { - common.SysError("error handling stream format: " + err.Error()) + logger.SysError("error handling stream format: " + err.Error()) } } if len(data) > 0 { @@ -143,7 +144,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re shouldSendLastResp := true if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage, &containStreamUsage, info, &shouldSendLastResp); err != nil { - common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData)) + logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData)) } if info.RelayFormat == relaycommon.RelayFormatOpenAI { @@ -154,7 +155,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re // 处理token计算 if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { - common.LogError(c, "error processing tokens: "+err.Error()) + logger.LogError(c, "error processing tokens: "+err.Error()) } if !containStreamUsage { @@ -173,7 +174,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) @@ -235,7 +236,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo responseBody = geminiRespStr } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &simpleResponse.Usage, nil } @@ -247,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel // if the upstream returns a specific status code, once the upstream has already written the header, // the subsequent failure of the response body should be regarded as a non-recoverable error, // and can be terminated directly. - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens usage.TotalTokens = info.PromptTokens @@ -258,13 +259,13 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel c.Writer.WriteHeaderNow() _, err := io.Copy(c.Writer, resp.Body) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } return usage } func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // count tokens by audio file duration audioTokens, err := countAudioTokens(c) @@ -276,7 +277,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) usage := &dto.Usage{} usage.PromptTokens = audioTokens @@ -386,7 +387,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.InputTokens += textToken + audioToken localUsage.InputTokenDetails.TextTokens += textToken @@ -459,7 +460,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken info.IsFirstRequest = false localUsage.InputTokens += textToken + audioToken @@ -474,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. localUsage = &dto.RealtimeUsage{} // print now usage } - common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session @@ -491,7 +492,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.OutputTokens += textToken + audioToken localUsage.OutputTokenDetails.TextTokens += textToken @@ -517,7 +518,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. case <-targetClosed: case err := <-errChan: //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil - common.LogError(c, "realtime error: "+err.Error()) + logger.LogError(c, "realtime error: "+err.Error()) case <-c.Done(): } @@ -553,7 +554,7 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R } func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -567,7 +568,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) // Once we've written to the client, we should not return errors anymore // because the upstream has already consumed resources and returned content diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index bae6fcb6..754a6f44 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -16,7 +17,7 @@ import ( ) func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // read response body var responsesResponse dto.OpenAIResponsesResponse @@ -33,7 +34,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) // compute usage usage := dto.Usage{} @@ -54,7 +55,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") + logger.LogError(c, "invalid response or response body") return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) } diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 9b8bce7d..1264b2b4 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -58,15 +59,15 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, go func() { responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError("error reading stream response: " + err.Error()) + logger.SysError("error reading stream response: " + err.Error()) stopChan <- true return } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) stopChan <- true return } @@ -78,7 +79,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, } jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) stopChan <- true return } @@ -96,7 +97,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, responseText } @@ -105,7 +106,7 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { @@ -133,6 +134,6 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &usage, nil } diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go index 2e37ad15..b21faccb 100644 --- a/relay/channel/siliconflow/relay-siliconflow.go +++ b/relay/channel/siliconflow/relay-siliconflow.go @@ -4,9 +4,9 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -17,7 +17,7 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var siliconflowResp SFRerankResponse err = json.Unmarshal(responseBody, &siliconflowResp) if err != nil { @@ -39,6 +39,6 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 9c04c7ad..1deb33fd 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" @@ -139,7 +140,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody)) if err != nil { - common.SysError(fmt.Sprintf("Get Task error: %v", err)) + logger.SysError(fmt.Sprintf("Get Task error: %v", err)) return nil, err } defer req.Body.Close() diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index 78ce6238..d3aeab3f 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -13,6 +13,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -106,7 +107,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt var tencentResponse TencentChatResponse err := json.Unmarshal([]byte(data), &tencentResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) continue } @@ -117,17 +118,17 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt err = helper.ObjectData(c, response) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) } } if err := scanner.Err(); err != nil { - common.SysError("error reading stream: " + err.Error()) + logger.SysError("error reading stream: " + err.Error()) } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil } @@ -138,7 +139,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &tencentSb) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) @@ -156,7 +157,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &fullTextResponse.Usage, nil } diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 4d098102..4d4e7b92 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -47,7 +48,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var xAIResp *dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &xAIResp) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } @@ -63,7 +64,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount) err = helper.ObjectData(c, openaiResponse) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) } return true }) @@ -74,12 +75,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return usage, nil } func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -101,7 +102,7 @@ func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.IOCopyBytesGracefully(c, resp, encodeJson) + service.IOCopyBytesGracefully(c, resp, encodeJson) return xaiResponse.Usage, nil } diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 1a426d50..398bb08d 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/relay/helper" "one-api/types" "strings" @@ -143,7 +144,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -218,20 +219,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap for { _, msg, err := conn.ReadMessage() if err != nil { - common.SysError("error reading stream response: " + err.Error()) + logger.SysError("error reading stream response: " + err.Error()) break } var response XunfeiChatResponse err = json.Unmarshal(msg, &response) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() if err != nil { - common.SysError("error closing websocket connection: " + err.Error()) + logger.SysError("error closing websocket connection: " + err.Error()) } break } @@ -282,6 +283,6 @@ func getAPIVersion(c *gin.Context, modelName string) string { return apiVersion } apiVersion = "v1.1" - common.SysLog("api_version not found, using default: " + apiVersion) + logger.SysLog("api_version not found, using default: " + apiVersion) return apiVersion } diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 35882ed5..65b662b6 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -8,8 +8,10 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" + "one-api/service" "one-api/types" "strings" "sync" @@ -38,7 +40,7 @@ func getZhipuToken(apikey string) string { split := strings.Split(apikey, ".") if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) + logger.SysError("invalid zhipu key: " + apikey) return "" } @@ -186,7 +188,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseZhipu2OpenAI(data) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -195,13 +197,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var zhipuResponse ZhipuStreamMetaResponse err := json.Unmarshal([]byte(data), &zhipuResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } usage = zhipuUsage @@ -212,7 +214,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return usage, nil } @@ -222,7 +224,7 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) diff --git a/relay/relay-mj.go b/relay/chat_handler.go similarity index 98% rename from relay/relay-mj.go rename to relay/chat_handler.go index e7f316b9..30bce55c 100644 --- a/relay/relay-mj.go +++ b/relay/chat_handler.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -214,7 +215,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + logger.SysError("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") @@ -300,7 +301,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") } - common.IOCopyBytesGracefully(c, nil, respBody) + service.IOCopyBytesGracefully(c, nil, respBody) return nil } @@ -521,7 +522,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + logger.SysError("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) @@ -572,7 +573,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //无实例账号自动禁用渠道(No available account instance) channel, err := model.GetChannelById(midjourneyTask.ChannelId, true) if err != nil { - common.SysError("get_channel_null: " + err.Error()) + logger.SysError("get_channel_null: " + err.Error()) } if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") diff --git a/relay/claude_handler.go b/relay/claude_handler.go index b4bf78ff..ddc424b4 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "errors" "fmt" "io" "net/http" @@ -18,68 +17,26 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { - textRequest = &dto.ClaudeRequest{} - err = c.ShouldBindJSON(textRequest) - if err != nil { - return nil, err - } - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return nil, errors.New("field messages is required") - } - if textRequest.Model == "" { - return nil, errors.New("field model is required") - } - return textRequest, nil -} +func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { -func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - relayInfo := relaycommon.GenRelayInfoClaude(c) + textRequest, ok := info.Request.(*dto.ClaudeRequest) - // get & validate textRequest 获取并验证文本请求 - textRequest, err := getAndValidateClaudeRequest(c) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) } - if textRequest.Stream { - relayInfo.IsStream = true - } - - err = helper.ModelMappedHelper(c, relayInfo, textRequest) + err := helper.ModelMappedHelper(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) - // count messages token error 计算promptTokens错误 - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) if textRequest.MaxTokens == 0 { textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model)) @@ -104,18 +61,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { textRequest.Temperature = common.GetPointer[float64](1.0) } textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") - relayInfo.UpstreamModelName = textRequest.Model + info.UpstreamModelName = textRequest.Model } var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) + convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -125,10 +82,10 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -145,14 +102,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { statusCodeMappingStr := c.GetString("status_code_mapping") var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -161,24 +118,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) //log.Printf("usage: %v", usage) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + + service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage)) return nil } - -func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - default: - promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName) - } - info.PromptTokens = promptTokens - return promptTokens, err -} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 5cd9223b..59be0011 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -1,10 +1,12 @@ package common import ( + "errors" "one-api/common" "one-api/constant" "one-api/dto" relayconstant "one-api/relay/constant" + "one-api/types" "strings" "time" @@ -33,17 +35,6 @@ type ClaudeConvertInfo struct { Done bool } -const ( - RelayFormatOpenAI = "openai" - RelayFormatClaude = "claude" - RelayFormatGemini = "gemini" - RelayFormatOpenAIResponses = "openai_responses" - RelayFormatOpenAIAudio = "openai_audio" - RelayFormatOpenAIImage = "openai_image" - RelayFormatRerank = "rerank" - RelayFormatEmbedding = "embedding" -) - type RerankerInfo struct { Documents []any ReturnDocuments bool @@ -59,61 +50,103 @@ type ResponsesUsageInfo struct { BuiltInTools map[string]*BuildInToolInfo } -type RelayInfo struct { +type ChannelMeta struct { ChannelType int ChannelId int - ChannelIsMultiKey bool // 是否多密钥 - ChannelMultiKeyIndex int // 多密钥索引 - TokenId int - TokenKey string - UserId int - UsingGroup string // 使用的分组 - UserGroup string // 用户所在分组 - TokenUnlimited bool - StartTime time.Time - FirstResponseTime time.Time - isFirstResponse bool + ChannelIsMultiKey bool + ChannelMultiKeyIndex int + ChannelBaseUrl string + ApiType int + ApiVersion string + ApiKey string + Organization string + ChannelCreateTime int64 + ParamOverride map[string]interface{} + ChannelSetting dto.ChannelSettings + ChannelOtherSettings dto.ChannelOtherSettings + UpstreamModelName string + IsModelMapped bool +} + +type RelayInfo struct { + TokenId int + TokenKey string + UserId int + UsingGroup string // 使用的分组 + UserGroup string // 用户所在分组 + TokenUnlimited bool + StartTime time.Time + FirstResponseTime time.Time + isFirstResponse bool //SendLastReasoningResponse bool - ApiType int IsStream bool IsGeminiBatchEmbedding bool IsPlayground bool UsePrice bool RelayMode int - UpstreamModelName string OriginModelName string //RecodeModelName string - RequestURLPath string - ApiVersion string - PromptTokens int - ApiKey string - Organization string - BaseUrl string - SupportStreamOptions bool - ShouldIncludeUsage bool - DisablePing bool // 是否禁止向下游发送自定义 Ping - IsModelMapped bool - ClientWs *websocket.Conn - TargetWs *websocket.Conn - InputAudioFormat string - OutputAudioFormat string - RealtimeTools []dto.RealTimeTool - IsFirstRequest bool - AudioUsage bool - ReasoningEffort string - ChannelSetting dto.ChannelSettings - ChannelOtherSettings dto.ChannelOtherSettings - ParamOverride map[string]interface{} - UserSetting dto.UserSetting - UserEmail string - UserQuota int - RelayFormat string - SendResponseCount int - ChannelCreateTime int64 + RequestURLPath string + PromptTokens int + SupportStreamOptions bool + ShouldIncludeUsage bool + DisablePing bool // 是否禁止向下游发送自定义 Ping + ClientWs *websocket.Conn + TargetWs *websocket.Conn + InputAudioFormat string + OutputAudioFormat string + RealtimeTools []dto.RealTimeTool + IsFirstRequest bool + AudioUsage bool + ReasoningEffort string + UserSetting dto.UserSetting + UserEmail string + UserQuota int + RelayFormat types.RelayFormat + SendResponseCount int + FinalPreConsumedQuota int // 最终预消耗的配额 + + PriceData types.PriceData + + Request dto.Request + ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo *ResponsesUsageInfo + *ChannelMeta +} + +func (info *RelayInfo) InitChannelMeta(c *gin.Context) { + channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) + paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) + apiType, _ := common.ChannelType2APIType(channelType) + channelMeta := &ChannelMeta{ + ChannelType: channelType, + ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId), + ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), + ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex), + ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl), + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), + Organization: c.GetString("channel_organization"), + ChannelCreateTime: c.GetInt64("channel_create_time"), + ParamOverride: paramOverride, + UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), + IsModelMapped: false, + } + + channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) + if ok { + channelMeta.ChannelSetting = channelSetting + } + + channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting) + if ok { + channelMeta.ChannelOtherSettings = channelOtherSettings + } + info.ChannelMeta = channelMeta } // 定义支持流式选项的通道类型 @@ -132,7 +165,8 @@ var streamSupportedChannels = map[int]bool{ } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { - info := GenRelayInfo(c) + info := genBaseRelayInfo(c, nil) + info.RelayFormat = types.RelayFormatOpenAIRealtime info.ClientWs = ws info.InputAudioFormat = "pcm16" info.OutputAudioFormat = "pcm16" @@ -140,9 +174,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { return info } -func GenRelayInfoClaude(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatClaude +func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatClaude info.ShouldIncludeUsage = false info.ClaudeConvertInfo = &ClaudeConvertInfo{ LastMessagesType: LastMessageTypeNone, @@ -150,41 +184,41 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { return info } -func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { - info := GenRelayInfo(c) +func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo { + info := genBaseRelayInfo(c, request) info.RelayMode = relayconstant.RelayModeRerank - info.RelayFormat = RelayFormatRerank + info.RelayFormat = types.RelayFormatRerank info.RerankerInfo = &RerankerInfo{ - Documents: req.Documents, - ReturnDocuments: req.GetReturnDocuments(), + Documents: request.Documents, + ReturnDocuments: request.GetReturnDocuments(), } return info } -func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatOpenAIAudio +func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAIAudio return info } -func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatEmbedding +func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatEmbedding return info } -func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { - info := GenRelayInfo(c) +func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo { + info := genBaseRelayInfo(c, request) info.RelayMode = relayconstant.RelayModeResponses - info.RelayFormat = RelayFormatOpenAIResponses + info.RelayFormat = types.RelayFormatOpenAIResponses info.SupportStreamOptions = false info.ResponsesUsageInfo = &ResponsesUsageInfo{ BuiltInTools: make(map[string]*BuildInToolInfo), } - if len(req.Tools) > 0 { - for _, tool := range req.Tools { + if len(request.Tools) > 0 { + for _, tool := range request.Tools { toolType := common.Interface2String(tool["type"]) info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{ ToolName: toolType, @@ -200,104 +234,76 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel } } } - info.IsStream = req.Stream return info } -func GenRelayInfoGemini(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatGemini +func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatGemini info.ShouldIncludeUsage = false + return info } -func GenRelayInfoImage(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatOpenAIImage +func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAIImage return info } -func GenRelayInfo(c *gin.Context) *RelayInfo { - channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) - channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) - paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) +func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAI + return info +} + +func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { + + //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) + //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) + //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) - tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) - tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey) - userId := common.GetContextKeyInt(c, constant.ContextKeyUserId) - tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited) startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime) if startTime.IsZero() { startTime = time.Now() } + // firstResponseTime = time.Now() - 1 second - apiType, _ := common.ChannelType2APIType(channelType) - info := &RelayInfo{ - UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), - UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), - isFirstResponse: true, - RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), - BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl), - RequestURLPath: c.Request.URL.String(), - ChannelType: channelType, - ChannelId: channelId, - TokenId: tokenId, - TokenKey: tokenKey, - UserId: userId, - UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), - UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), - TokenUnlimited: tokenUnlimited, + Request: request, + + UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId), + UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), + UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), + UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), + UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), + + OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), + PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens), + + TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId), + TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey), + TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited), + + isFirstResponse: true, + RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), + RequestURLPath: c.Request.URL.String(), + IsStream: request.IsStream(c), + StartTime: startTime, FirstResponseTime: startTime.Add(-time.Second), - OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), - UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), - //RecodeModelName: c.GetString("original_model"), - IsModelMapped: false, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), - Organization: c.GetString("channel_organization"), - - ChannelCreateTime: c.GetInt64("channel_create_time"), - ParamOverride: paramOverride, - RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, }, - - ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), - ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex), } + if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg") info.RequestURLPath = "/v1" + info.RequestURLPath } - if info.BaseUrl == "" { - info.BaseUrl = constant.ChannelBaseURLs[channelType] - } - if info.ChannelType == constant.ChannelTypeAzure { - info.ApiVersion = GetAPIVersion(c) - } - if info.ChannelType == constant.ChannelTypeVertexAi { - info.ApiVersion = c.GetString("region") - } - if streamSupportedChannels[info.ChannelType] { - info.SupportStreamOptions = true - } - - channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) - if ok { - info.ChannelSetting = channelSetting - } - - channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting) - if ok { - info.ChannelOtherSettings = channelOtherSettings - } userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting) if ok { @@ -307,12 +313,39 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { return info } -func (info *RelayInfo) SetPromptTokens(promptTokens int) { - info.PromptTokens = promptTokens +func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) { + switch relayFormat { + case types.RelayFormatOpenAI: + return GenRelayInfoOpenAI(c, request), nil + case types.RelayFormatOpenAIAudio: + return GenRelayInfoOpenAIAudio(c, request), nil + case types.RelayFormatOpenAIImage: + return GenRelayInfoImage(c, request), nil + case types.RelayFormatOpenAIRealtime: + return GenRelayInfoWs(c, ws), nil + case types.RelayFormatClaude: + return GenRelayInfoClaude(c, request), nil + case types.RelayFormatRerank: + if request, ok := request.(*dto.RerankRequest); ok { + return GenRelayInfoRerank(c, request), nil + } + return nil, errors.New("request is not a RerankRequest") + case types.RelayFormatGemini: + return GenRelayInfoGemini(c, request), nil + case types.RelayFormatEmbedding: + return GenRelayInfoEmbedding(c, request), nil + case types.RelayFormatOpenAIResponses: + if request, ok := request.(*dto.OpenAIResponsesRequest); ok { + return GenRelayInfoResponses(c, request), nil + } + return nil, errors.New("request is not a OpenAIResponsesRequest") + default: + return nil, errors.New("invalid relay format") + } } -func (info *RelayInfo) SetIsStream(isStream bool) { - info.IsStream = isStream +func (info *RelayInfo) SetPromptTokens(promptTokens int) { + info.PromptTokens = promptTokens } func (info *RelayInfo) SetFirstResponseTime() { diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index 57df5fe3..05dbfa6d 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -8,6 +8,7 @@ import ( "one-api/dto" "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -18,7 +19,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println("reranker response body: ", string(responseBody)) } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index fef8d2c9..f7906cf9 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -8,7 +8,6 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "one-api/types" @@ -16,69 +15,27 @@ import ( "github.com/gin-gonic/gin" ) -func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { - token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) - return token -} +func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { -func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error { - if embeddingRequest.Input == nil { - return fmt.Errorf("input is empty") - } - if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { - embeddingRequest.Model = "omni-moderation-latest" - } - if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { - embeddingRequest.Model = c.Param("model") - } - return nil -} + info.InitChannelMeta(c) -func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoEmbedding(c) - - var embeddingRequest *dto.EmbeddingRequest - err := common.UnmarshalBodyReusable(c, &embeddingRequest) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) } - err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) + err := helper.ModelMappedHelper(c, info, embeddingRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptToken := getEmbeddingPromptToken(*embeddingRequest) - relayInfo.PromptTokens = promptToken - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) - convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest) + convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -88,7 +45,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -104,12 +61,12 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index e0581156..3ebe0884 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -2,17 +2,16 @@ package relay import ( "bytes" - "errors" "fmt" "io" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/gemini" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -20,64 +19,6 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { - request := &dto.GeminiChatRequest{} - err := common.UnmarshalBodyReusable(c, request) - if err != nil { - return nil, err - } - if len(request.Contents) == 0 { - return nil, errors.New("contents is required") - } - return request, nil -} - -// 流模式 -// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx -func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) { - if c.Query("alt") == "sse" { - relayInfo.IsStream = true - } - - // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") { - // relayInfo.IsStream = true - // } -} - -func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) { - var inputTexts []string - for _, content := range textRequest.Contents { - for _, part := range content.Parts { - if part.Text != "" { - inputTexts = append(inputTexts, part.Text) - } - } - } - if len(inputTexts) == 0 { - return nil, nil - } - - sensitiveWords, err := service.CheckSensitiveInput(inputTexts) - return sensitiveWords, err -} - -func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int { - // 计算输入 token 数量 - var inputTexts []string - for _, content := range req.Contents { - for _, part := range content.Parts { - if part.Text != "" { - inputTexts = append(inputTexts, part.Text) - } - } - } - - inputText := strings.Join(inputTexts, "\n") - inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName) - info.PromptTokens = inputTokens - return inputTokens -} - func isNoThinkingRequest(req *dto.GeminiChatRequest) bool { if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget @@ -109,97 +50,61 @@ func trimModelThinking(modelName string) string { return modelName } -func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - req, err := getAndValidateGeminiRequest(c) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } +func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - relayInfo := relaycommon.GenRelayInfoGemini(c) - - // 检查 Gemini 流式模式 - checkGeminiStreamMode(c, relayInfo) - - if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkGeminiInputSensitive(req) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) - } + request, ok := info.Request.(*dto.GeminiChatRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request)) } // model mapped 模型映射 - err = helper.ModelMappedHelper(c, relayInfo, req) + err := helper.ModelMappedHelper(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens := value.(int) - relayInfo.SetPromptTokens(promptTokens) - } else { - promptTokens := getGeminiInputTokens(req, relayInfo) - c.Set("prompt_tokens", promptTokens) - } - if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - if isNoThinkingRequest(req) { + if isNoThinkingRequest(request) { // check is thinking - if !strings.Contains(relayInfo.OriginModelName, "-nothinking") { + if !strings.Contains(info.OriginModelName, "-nothinking") { // try to get no thinking model price - noThinkingModelName := relayInfo.OriginModelName + "-nothinking" + noThinkingModelName := info.OriginModelName + "-nothinking" containPrice := helper.ContainPriceOrRatio(noThinkingModelName) if containPrice { - relayInfo.OriginModelName = noThinkingModelName - relayInfo.UpstreamModelName = noThinkingModelName + info.OriginModelName = noThinkingModelName + info.UpstreamModelName = noThinkingModelName } } } - if req.GenerationConfig.ThinkingConfig == nil { - gemini.ThinkingAdaptor(req, relayInfo) + if request.GenerationConfig.ThinkingConfig == nil { + gemini.ThinkingAdaptor(request, info) } } - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre consume quota - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) // Clean up empty system instruction - if req.SystemInstructions != nil { + if request.SystemInstructions != nil { hasContent := false - for _, part := range req.SystemInstructions.Parts { + for _, part := range request.SystemInstructions.Parts { if part.Text != "" { hasContent = true break } } if !hasContent { - req.SystemInstructions = nil + request.SystemInstructions = nil } } var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) @@ -207,7 +112,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { requestBody = bytes.NewReader(body) } else { // 使用 ConvertGeminiRequest 转换请求格式 - convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req) + convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -217,10 +122,10 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -229,15 +134,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - if common.DebugEnabled { - println("Gemini request body: %s", string(jsonData)) - } + logger.LogDebug(c, "Gemini request body: "+string(jsonData)) + requestBody = bytes.NewReader(jsonData) } - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - common.LogError(c, "Do gemini request failed: "+err.Error()) + logger.LogError(c, "Do gemini request failed: "+err.Error()) return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -246,7 +150,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -255,23 +159,22 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) if openaiErr != nil { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } -func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoGemini(c) +func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents") - relayInfo.IsGeminiBatchEmbedding = isBatch + info.IsGeminiBatchEmbedding = isBatch - var promptTokens int var req any var err error var inputTexts []string @@ -303,35 +206,17 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } } } - promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName) - relayInfo.SetPromptTokens(promptTokens) - c.Set("prompt_tokens", promptTokens) - err = helper.ModelMappedHelper(c, relayInfo, req) + err = helper.ModelMappedHelper(c, info, req) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader jsonData, err := common.Marshal(req) @@ -340,10 +225,10 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -353,9 +238,9 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewReader(jsonData) - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - common.LogError(c, "Do gemini request failed: "+err.Error()) + logger.LogError(c, "Do gemini request failed: "+err.Error()) return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -370,12 +255,12 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) if openaiErr != nil { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/helper/common.go b/relay/helper/common.go index c8edb798..5075314d 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/types" "github.com/gin-gonic/gin" @@ -100,7 +101,7 @@ func Done(c *gin.Context) { func WssString(c *gin.Context, ws *websocket.Conn, str string) error { if ws == nil { - common.LogError(c, "websocket connection is nil") + logger.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } //common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) @@ -113,7 +114,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { return fmt.Errorf("error marshalling object: %w", err) } if ws == nil { - common.LogError(c, "websocket connection is nil") + logger.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } //common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go index c1735149..e894e228 100644 --- a/relay/helper/model_mapped.go +++ b/relay/helper/model_mapped.go @@ -4,9 +4,10 @@ import ( "encoding/json" "errors" "fmt" - common2 "one-api/common" "one-api/dto" + common2 "one-api/logger" "one-api/relay/common" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -54,29 +55,29 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro } if request != nil { switch info.RelayFormat { - case common.RelayFormatGemini: + case types.RelayFormatGemini: // Gemini 模型映射 - case common.RelayFormatClaude: + case types.RelayFormatClaude: if claudeRequest, ok := request.(*dto.ClaudeRequest); ok { claudeRequest.Model = info.UpstreamModelName } - case common.RelayFormatOpenAIResponses: + case types.RelayFormatOpenAIResponses: if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok { openAIResponsesRequest.Model = info.UpstreamModelName } - case common.RelayFormatOpenAIAudio: + case types.RelayFormatOpenAIAudio: if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok { openAIAudioRequest.Model = info.UpstreamModelName } - case common.RelayFormatOpenAIImage: + case types.RelayFormatOpenAIImage: if imageRequest, ok := request.(*dto.ImageRequest); ok { imageRequest.Model = info.UpstreamModelName } - case common.RelayFormatRerank: + case types.RelayFormatRerank: if rerankRequest, ok := request.(*dto.RerankRequest); ok { rerankRequest.Model = info.UpstreamModelName } - case common.RelayFormatEmbedding: + case types.RelayFormatEmbedding: if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok { embeddingRequest.Model = info.UpstreamModelName } diff --git a/relay/helper/price.go b/relay/helper/price.go index e80578e5..89fc3b66 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -5,35 +5,14 @@ import ( "one-api/common" relaycommon "one-api/relay/common" "one-api/setting/ratio_setting" + "one-api/types" "github.com/gin-gonic/gin" ) -type GroupRatioInfo struct { - GroupRatio float64 - GroupSpecialRatio float64 - HasSpecialRatio bool -} - -type PriceData struct { - ModelPrice float64 - ModelRatio float64 - CompletionRatio float64 - CacheRatio float64 - CacheCreationRatio float64 - ImageRatio float64 - UsePrice bool - ShouldPreConsumedQuota int - GroupRatioInfo GroupRatioInfo -} - -func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) -} - // HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present -func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo { - groupRatioInfo := GroupRatioInfo{ +func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo { + groupRatioInfo := types.GroupRatioInfo{ GroupRatio: 1.0, // default ratio GroupSpecialRatio: -1, } @@ -62,7 +41,7 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR return groupRatioInfo } -func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { +func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) { modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false) groupRatioInfo := HandleGroupRatio(c, info) @@ -75,8 +54,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var cacheCreationRatio float64 if !usePrice { preConsumedTokens := common.PreConsumedQuota - if maxTokens != 0 { - preConsumedTokens = promptTokens + maxTokens + if meta.MaxTokens != 0 { + preConsumedTokens = promptTokens + meta.MaxTokens } var success bool var matchName string @@ -87,7 +66,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens acceptUnsetRatio = true } if !acceptUnsetRatio { - return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) + return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) } } completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName) @@ -97,10 +76,13 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens ratio := modelRatio * groupRatioInfo.GroupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { + if meta.ImagePriceRatio != 0 { + modelPrice = modelPrice * meta.ImagePriceRatio + } preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) } - priceData := PriceData{ + priceData := types.PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, CompletionRatio: completionRatio, @@ -115,38 +97,32 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens if common.DebugEnabled { println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting())) } - + info.PriceData = priceData return priceData, nil } -type PerCallPriceData struct { - ModelPrice float64 - Quota int - GroupRatioInfo GroupRatioInfo -} - // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData { - groupRatioInfo := HandleGroupRatio(c, info) - - modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) - // 如果没有配置价格,则使用默认价格 - if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] - if !ok { - modelPrice = 0.1 - } else { - modelPrice = defaultPrice - } - } - quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) - priceData := PerCallPriceData{ - ModelPrice: modelPrice, - Quota: quota, - GroupRatioInfo: groupRatioInfo, - } - return priceData -} +//func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { +// groupRatioInfo := HandleGroupRatio(c, info) +// +// modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) +// // 如果没有配置价格,则使用默认价格 +// if !success { +// defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] +// if !ok { +// modelPrice = 0.1 +// } else { +// modelPrice = defaultPrice +// } +// } +// quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) +// priceData := types.PerCallPriceData{ +// ModelPrice: modelPrice, +// Quota: quota, +// GroupRatioInfo: groupRatioInfo, +// } +// return priceData +//} func ContainPriceOrRatio(modelName string) bool { _, ok := ratio_setting.GetModelPrice(modelName, false) diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index a5706f95..725d178c 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/constant" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/setting/operation_setting" "strings" @@ -87,7 +88,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case <-done: case <-time.After(5 * time.Second): - common.LogError(c, "timeout waiting for goroutines to exit") + logger.LogError(c, "timeout waiting for goroutines to exit") } close(stopChan) @@ -109,7 +110,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer func() { wg.Done() if r := recover(); r != nil { - common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) + logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) common.SafeSendBool(stopChan, true) } if common.DebugEnabled { @@ -136,14 +137,14 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case err := <-done: if err != nil { - common.LogError(c, "ping data error: "+err.Error()) + logger.LogError(c, "ping data error: "+err.Error()) return } if common.DebugEnabled { println("ping data sent") } case <-time.After(10 * time.Second): - common.LogError(c, "ping data send timeout") + logger.LogError(c, "ping data send timeout") return case <-ctx.Done(): return @@ -158,7 +159,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon // 监听客户端断开连接 return case <-pingTimeout.C: - common.LogError(c, "ping goroutine max duration reached") + logger.LogError(c, "ping goroutine max duration reached") return } } @@ -171,7 +172,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer func() { wg.Done() if r := recover(); r != nil { - common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) + logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) } common.SafeSendBool(stopChan, true) if common.DebugEnabled { @@ -223,7 +224,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon return } case <-time.After(10 * time.Second): - common.LogError(c, "data handler timeout") + logger.LogError(c, "data handler timeout") return case <-ctx.Done(): return @@ -241,7 +242,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon if err := scanner.Err(); err != nil { if err != io.EOF { - common.LogError(c, "scanner error: "+err.Error()) + logger.LogError(c, "scanner error: "+err.Error()) } } }) @@ -250,12 +251,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case <-ticker.C: // 超时处理逻辑 - common.LogError(c, "streaming timeout") + logger.LogError(c, "streaming timeout") case <-stopChan: // 正常结束 - common.LogInfo(c, "streaming finished") + logger.LogInfo(c, "streaming finished") case <-c.Request.Context().Done(): // 客户端断开连接 - common.LogInfo(c, "client disconnected") + logger.LogInfo(c, "client disconnected") } } diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go new file mode 100644 index 00000000..0bc51774 --- /dev/null +++ b/relay/helper/valid_request.go @@ -0,0 +1,301 @@ +package helper + +import ( + "errors" + "fmt" + "math" + "one-api/common" + "one-api/dto" + "one-api/logger" + relayconstant "one-api/relay/constant" + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + +func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) { + relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) + + switch format { + case types.RelayFormatOpenAI: + request, err = GetAndValidateTextRequest(c, relayMode) + case types.RelayFormatGemini: + request, err = GetAndValidateGeminiRequest(c) + case types.RelayFormatClaude: + request, err = GetAndValidateClaudeRequest(c) + case types.RelayFormatOpenAIResponses: + request, err = GetAndValidateResponsesRequest(c) + + case types.RelayFormatOpenAIImage: + request, err = GetAndValidOpenAIImageRequest(c, relayMode) + case types.RelayFormatEmbedding: + request, err = GetAndValidateEmbeddingRequest(c, relayMode) + case types.RelayFormatRerank: + request, err = GetAndValidateRerankRequest(c) + case types.RelayFormatOpenAIAudio: + request, err = GetAndValidAudioRequest(c, relayMode) + case types.RelayFormatOpenAIRealtime: + // nothing to do, no request body + default: + return nil, fmt.Errorf("unsupported relay format: %s", format) + } + return request, err +} + +func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) { + audioRequest := &dto.AudioRequest{} + err := common.UnmarshalBodyReusable(c, audioRequest) + if err != nil { + return nil, err + } + switch relayMode { + case relayconstant.RelayModeAudioSpeech: + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + default: + err = c.Request.ParseForm() + if err != nil { + return nil, err + } + formData := c.Request.PostForm + if audioRequest.Model == "" { + audioRequest.Model = formData.Get("model") + } + + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + audioRequest.ResponseFormat = formData.Get("response_format") + if audioRequest.ResponseFormat == "" { + audioRequest.ResponseFormat = "json" + } + } + return audioRequest, nil +} + +func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) { + var rerankRequest *dto.RerankRequest + err := common.UnmarshalBodyReusable(c, &rerankRequest) + if err != nil { + logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + if rerankRequest.Query == "" { + return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + if len(rerankRequest.Documents) == 0 { + return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + return rerankRequest, nil +} + +func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) { + var embeddingRequest *dto.EmbeddingRequest + err := common.UnmarshalBodyReusable(c, &embeddingRequest) + if err != nil { + logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + if embeddingRequest.Input == nil { + return nil, fmt.Errorf("input is empty") + } + if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { + embeddingRequest.Model = "omni-moderation-latest" + } + if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { + embeddingRequest.Model = c.Param("model") + } + return embeddingRequest, nil +} + +func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { + request := &dto.OpenAIResponsesRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if request.Model == "" { + return nil, errors.New("model is required") + } + if request.Input == nil { + return nil, errors.New("input is required") + } + return request, nil +} + +func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) { + imageRequest := &dto.ImageRequest{} + + switch relayMode { + case relayconstant.RelayModeImagesEdits: + _, err := c.MultipartForm() + if err != nil { + return nil, err + } + formData := c.Request.PostForm + imageRequest.Prompt = formData.Get("prompt") + imageRequest.Model = formData.Get("model") + imageRequest.N = uint(common.String2Int(formData.Get("n"))) + imageRequest.Quality = formData.Get("quality") + imageRequest.Size = formData.Get("size") + + if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + + watermark := formData.Has("watermark") + if watermark { + imageRequest.Watermark = &watermark + } + default: + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-3" + } + + if strings.Contains(imageRequest.Size, "×") { + return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") + } + + // Not "256x256", "512x512", or "1024x1024" + if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { + if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + } else if imageRequest.Model == "dall-e-3" { + if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { + return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") + } + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + } else if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "auto" + } + } + + if imageRequest.Prompt == "" { + return nil, errors.New("prompt is required") + } + + if imageRequest.N == 0 { + imageRequest.N = 1 + } + } + + return imageRequest, nil +} + +func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { + textRequest = &dto.ClaudeRequest{} + err = c.ShouldBindJSON(textRequest) + if err != nil { + return nil, err + } + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + if textRequest.Model == "" { + return nil, errors.New("field model is required") + } + + //if textRequest.Stream { + // relayInfo.IsStream = true + //} + + return textRequest, nil +} + +func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) { + textRequest := &dto.GeneralOpenAIRequest{} + err := common.UnmarshalBodyReusable(c, textRequest) + if err != nil { + return nil, err + } + + if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } + if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } + + if textRequest.MaxTokens > math.MaxInt32/2 { + return nil, errors.New("max_tokens is invalid") + } + if textRequest.Model == "" { + return nil, errors.New("model is required") + } + if textRequest.WebSearchOptions != nil { + if textRequest.WebSearchOptions.SearchContextSize != "" { + validSizes := map[string]bool{ + "high": true, + "medium": true, + "low": true, + } + if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { + return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") + } + } else { + textRequest.WebSearchOptions.SearchContextSize = "medium" + } + } + switch relayMode { + case relayconstant.RelayModeCompletions: + if textRequest.Prompt == "" { + return nil, errors.New("field prompt is required") + } + case relayconstant.RelayModeChatCompletions: + if len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + case relayconstant.RelayModeEmbeddings: + case relayconstant.RelayModeModerations: + if textRequest.Input == nil || textRequest.Input == "" { + return nil, errors.New("field input is required") + } + case relayconstant.RelayModeEdits: + if textRequest.Instruction == "" { + return nil, errors.New("field instruction is required") + } + } + return textRequest, nil +} + +func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { + + request := &dto.GeminiChatRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if len(request.Contents) == 0 { + return nil, errors.New("contents is required") + } + + //if c.Query("alt") == "sse" { + // relayInfo.IsStream = true + //} + + return request, nil +} diff --git a/relay/image_handler.go b/relay/image_handler.go index f0b69699..008a979d 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -3,19 +3,15 @@ package relay import ( "bytes" "encoding/json" - "errors" "fmt" "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" - "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -23,183 +19,41 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) { - imageRequest := &dto.ImageRequest{} +func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { - switch info.RelayMode { - case relayconstant.RelayModeImagesEdits: - _, err := c.MultipartForm() - if err != nil { - return nil, err - } - formData := c.Request.PostForm - imageRequest.Prompt = formData.Get("prompt") - imageRequest.Model = formData.Get("model") - imageRequest.N = common.String2Int(formData.Get("n")) - imageRequest.Quality = formData.Get("quality") - imageRequest.Size = formData.Get("size") + info.InitChannelMeta(c) - if imageRequest.Model == "gpt-image-1" { - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } + imageRequest, ok := info.Request.(*dto.ImageRequest) - if info.ApiType == constant.APITypeVolcEngine { - watermark := formData.Has("watermark") - imageRequest.Watermark = &watermark - } - default: - err := common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err - } - - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-3" - } - - if strings.Contains(imageRequest.Size, "×") { - return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") - } - - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - } else if imageRequest.Model == "dall-e-3" { - if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") - } - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - } else if imageRequest.Model == "gpt-image-1" { - if imageRequest.Quality == "" { - imageRequest.Quality = "auto" - } - } - - if imageRequest.Prompt == "" { - return nil, errors.New("prompt is required") - } - - if imageRequest.N == 0 { - imageRequest.N = 1 - } + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request)) } - if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveInput(imageRequest.Prompt) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) - return nil, err - } - } - return imageRequest, nil -} - -func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoImage(c) - - imageRequest, err := getAndValidImageRequest(c, relayInfo) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - err = helper.ModelMappedHelper(c, relayInfo, imageRequest) + err := helper.ModelMappedHelper(c, info, imageRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - var preConsumedQuota int - var quota int - var userQuota int - if !priceData.UsePrice { - // modelRatio 16 = modelPrice $0.04 - // per 1 modelRatio = $0.04 / 16 - // priceData.ModelPrice = 0.0025 * priceData.ModelRatio - preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - } else { - sizeRatio := 1.0 - qualityRatio := 1.0 - - if strings.HasPrefix(imageRequest.Model, "dall-e") { - // Size - if imageRequest.Size == "256x256" { - sizeRatio = 0.4 - } else if imageRequest.Size == "512x512" { - sizeRatio = 0.45 - } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1 - } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - sizeRatio = 2 - } - - if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" { - qualityRatio = 2.0 - if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - qualityRatio = 1.5 - } - } - } - - // reset model price - priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) - quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) - userQuota, err = model.GetUserQuota(relayInfo.UserId, false) - if err != nil { - return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) - } - if userQuota-quota < 0 { - return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry()) - } - } - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) + convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits { + if info.RelayMode == relayconstant.RelayModeImagesEdits { requestBody = convertedRequest.(io.Reader) } else { jsonData, err := json.Marshal(convertedRequest) @@ -208,10 +62,10 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -229,14 +83,14 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -245,7 +99,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) @@ -253,17 +107,23 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } if usage.(*dto.Usage).TotalTokens == 0 { - usage.(*dto.Usage).TotalTokens = imageRequest.N + usage.(*dto.Usage).TotalTokens = int(imageRequest.N) } if usage.(*dto.Usage).PromptTokens == 0 { - usage.(*dto.Usage).PromptTokens = imageRequest.N + usage.(*dto.Usage).PromptTokens = int(imageRequest.N) } + quality := "standard" if imageRequest.Quality == "hd" { quality = "hd" } - logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent) + var logContent string + + if len(imageRequest.Size) > 0 { + logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) + } + + postConsumeQuota(c, info, usage.(*dto.Usage), logContent) return nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index 50d574f3..de750e76 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -2,172 +2,56 @@ package relay import ( "bytes" - "errors" "fmt" "io" - "math" "net/http" "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/setting/operation_setting" "one-api/types" "strings" "time" - "github.com/bytedance/gopkg/util/gopool" "github.com/shopspring/decimal" "github.com/gin-gonic/gin" ) -func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) { - textRequest := &dto.GeneralOpenAIRequest{} - err := common.UnmarshalBodyReusable(c, textRequest) - if err != nil { - return nil, err - } - if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { - textRequest.Model = "text-moderation-latest" - } - if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { - textRequest.Model = c.Param("model") - } +func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { - if textRequest.MaxTokens > math.MaxInt32/2 { - return nil, errors.New("max_tokens is invalid") - } - if textRequest.Model == "" { - return nil, errors.New("model is required") - } - if textRequest.WebSearchOptions != nil { - if textRequest.WebSearchOptions.SearchContextSize != "" { - validSizes := map[string]bool{ - "high": true, - "medium": true, - "low": true, - } - if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { - return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") - } - } else { - textRequest.WebSearchOptions.SearchContextSize = "medium" - } - } - switch relayInfo.RelayMode { - case relayconstant.RelayModeCompletions: - if textRequest.Prompt == "" { - return nil, errors.New("field prompt is required") - } - case relayconstant.RelayModeChatCompletions: - if len(textRequest.Messages) == 0 { - return nil, errors.New("field messages is required") - } - case relayconstant.RelayModeEmbeddings: - case relayconstant.RelayModeModerations: - if textRequest.Input == nil || textRequest.Input == "" { - return nil, errors.New("field input is required") - } - case relayconstant.RelayModeEdits: - if textRequest.Instruction == "" { - return nil, errors.New("field instruction is required") - } - } - relayInfo.IsStream = textRequest.Stream - return textRequest, nil -} + info.InitChannelMeta(c) -func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { + textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest) - relayInfo := relaycommon.GenRelayInfo(c) - - // get & validate textRequest 获取并验证文本请求 - textRequest, err := getAndValidateTextRequest(c, relayInfo) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + if !ok { + //return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request) } if textRequest.WebSearchOptions != nil { c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize) } - if setting.ShouldCheckPromptSensitive() { - words, err := checkRequestSensitive(textRequest, relayInfo) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) - } - } - - err = helper.ModelMappedHelper(c, relayInfo, textRequest) + err := helper.ModelMappedHelper(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - // 获取 promptTokens,如果上下文中已经存在,则直接使用 - var promptTokens int - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens = value.(int) - relayInfo.PromptTokens = promptTokens - } else { - promptTokens, err = getPromptTokens(textRequest, relayInfo) - // count messages token error 计算promptTokens错误 - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) - } - c.Set("prompt_tokens", promptTokens) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens)))) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newApiErr != nil { - return newApiErr - } - defer func() { - if newApiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - includeUsage := true - // 判断用户是否需要返回使用情况 - if textRequest.StreamOptions != nil { - includeUsage = textRequest.StreamOptions.IncludeUsage - } - - // 如果不支持StreamOptions,将StreamOptions设置为nil - if !relayInfo.SupportStreamOptions || !textRequest.Stream { - textRequest.StreamOptions = nil - } else { - // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions - if constant.ForceStreamOption { - textRequest.StreamOptions = &dto.StreamOptions{ - IncludeUsage: true, - } - } - } - - relayInfo.ShouldIncludeUsage = includeUsage - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) @@ -177,12 +61,12 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) + convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - if relayInfo.ChannelSetting.SystemPrompt != "" { + if info.ChannelSetting.SystemPrompt != "" { // 如果有系统提示,则将其添加到请求中 request := convertedRequest.(*dto.GeneralOpenAIRequest) containSystemPrompt := false @@ -196,22 +80,22 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { // 如果没有系统提示,则添加系统提示 systemMessage := dto.Message{ Role: request.GetSystemRoleName(), - Content: relayInfo.ChannelSetting.SystemPrompt, + Content: info.ChannelSetting.SystemPrompt, } request.Messages = append([]dto.Message{systemMessage}, request.Messages...) - } else if relayInfo.ChannelSetting.SystemPromptOverride { + } else if info.ChannelSetting.SystemPromptOverride { common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) // 如果有系统提示,且允许覆盖,则拼接到前面 for i, message := range request.Messages { if message.Role == request.GetSystemRoleName() { if message.IsStringContent() { - request.Messages[i].SetStringContent(relayInfo.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) + request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) } else { contents := message.ParseContent() contents = append([]dto.MediaContent{ { Type: dto.ContentTypeText, - Text: relayInfo.ChannelSetting.SystemPrompt, + Text: info.ChannelSetting.SystemPrompt, }, }, contents...) request.Messages[i].Content = contents @@ -228,10 +112,10 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -240,14 +124,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - if common.DebugEnabled { - println("requestBody: ", string(jsonData)) - } + logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData))) + requestBody = bytes.NewBuffer(jsonData) } var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -256,125 +139,31 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newApiErr = service.RelayErrorHandler(httpResp, false) + newApiErr := service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } } - usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newApiErr := adaptor.DoResponse(c, httpResp, info) if newApiErr != nil { // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } - if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") } return nil } -func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - promptTokens, err = service.CountTokenChatRequest(info, *textRequest) - case relayconstant.RelayModeCompletions: - promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model) - case relayconstant.RelayModeModerations: - promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) - case relayconstant.RelayModeEmbeddings: - promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) - default: - err = errors.New("unknown relay mode") - promptTokens = 0 - } - info.PromptTokens = promptTokens - return promptTokens, err -} - -func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) { - var err error - var words []string - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - words, err = service.CheckSensitiveMessages(textRequest.Messages) - case relayconstant.RelayModeCompletions: - words, err = service.CheckSensitiveInput(textRequest.Prompt) - case relayconstant.RelayModeModerations: - words, err = service.CheckSensitiveInput(textRequest.Input) - case relayconstant.RelayModeEmbeddings: - words, err = service.CheckSensitiveInput(textRequest.Input) - } - return words, err -} - -// 预扣费并返回用户剩余配额 -func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) { - userQuota, err := model.GetUserQuota(relayInfo.UserId, false) - if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) - } - if userQuota <= 0 { - return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - if userQuota-preConsumedQuota < 0 { - return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - relayInfo.UserQuota = userQuota - if userQuota > 100*preConsumedQuota { - // 用户额度充足,判断令牌额度是否充足 - if !relayInfo.TokenUnlimited { - // 非无限令牌,判断令牌额度是否充足 - tokenQuota := c.GetInt("token_quota") - if tokenQuota > 100*preConsumedQuota { - // 令牌额度充足,信任令牌 - preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) - } - } else { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota))) - } - } - - if preConsumedQuota > 0 { - err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) - if err != nil { - return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) - if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) - } - } - return preConsumedQuota, userQuota, nil -} - -func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) { - if preConsumedQuota != 0 { - gopool.Go(func() { - relayInfoCopy := *relayInfo - - err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) - if err != nil { - common.SysError("error return pre-consumed quota: " + err.Error()) - } - }) - } -} - -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { if usage == nil { usage = &dto.Usage{ PromptTokens: relayInfo.PromptTokens, @@ -392,12 +181,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := priceData.CompletionRatio - cacheRatio := priceData.CacheRatio - imageRatio := priceData.ImageRatio - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice + completionRatio := relayInfo.PriceData.CompletionRatio + cacheRatio := relayInfo.PriceData.CacheRatio + imageRatio := relayInfo.PriceData.ImageRatio + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -470,7 +259,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, var audioInputQuota decimal.Decimal var audioInputPrice float64 - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { baseTokens := dPromptTokens // 减去 cached tokens var cachedTokensWithRatio decimal.Decimal @@ -518,7 +307,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, totalTokens := promptTokens + completionTokens var logContent string - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio) } else { logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) @@ -530,8 +319,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { if !ratio.IsZero() && quota == 0 { quota = 1 @@ -540,11 +329,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota if quotaDelta != 0 { - err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } @@ -560,7 +349,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, if extraContent != "" { logContent += ", " + extraContent } - other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) if imageTokens != 0 { other["image"] = true other["image_ratio"] = imageRatio @@ -604,7 +393,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, diff --git a/relay/relay_task.go b/relay/relay_task.go index 0ccc3b33..ae002d73 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -127,7 +128,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + logger.SysError("error consuming token remain quota: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 1e547e2a..85e4f174 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -25,62 +25,33 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int { return token } -func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) { +func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { - var rerankRequest *dto.RerankRequest - err := common.UnmarshalBodyReusable(c, &rerankRequest) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + rerankRequest, ok := info.Request.(*dto.RerankRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request)) } - relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest) - - if rerankRequest.Query == "" { - return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - if len(rerankRequest.Documents) == 0 { - return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) + err := helper.ModelMappedHelper(c, info, rerankRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptToken := getRerankPromptToken(*rerankRequest) - relayInfo.PromptTokens = promptToken - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) + convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -90,10 +61,10 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -108,7 +79,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError requestBody = bytes.NewBuffer(jsonData) } - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -125,12 +96,12 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 65c240b2..cd80da33 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -3,7 +3,6 @@ package relay import ( "bytes" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -12,7 +11,6 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -20,82 +18,24 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { - request := &dto.OpenAIResponsesRequest{} - err := common.UnmarshalBodyReusable(c, request) - if err != nil { - return nil, err - } - if request.Model == "" { - return nil, errors.New("model is required") - } - if len(request.Input) == 0 { - return nil, errors.New("input is required") - } - return request, nil +func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) -} - -func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { - sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) - return sensitiveWords, err -} - -func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int { - inputTokens := service.CountTokenInput(req.Input, req.Model) - info.PromptTokens = inputTokens - return inputTokens -} - -func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - req, err := getAndValidateResponsesRequest(c) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + request, ok := info.Request.(*dto.OpenAIResponsesRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request)) } - relayInfo := relaycommon.GenRelayInfoResponses(c, req) - - if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkInputSensitive(req, relayInfo) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) - } - } - - err = helper.ModelMappedHelper(c, relayInfo, req) + err := helper.ModelMappedHelper(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens := value.(int) - relayInfo.SetPromptTokens(promptTokens) - } else { - promptTokens := getInputTokens(req, relayInfo) - c.Set("prompt_tokens", promptTokens) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - // pre consume quota - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled { body, err := common.GetRequestBody(c) @@ -104,7 +44,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req) + convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -113,13 +53,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) err = json.Unmarshal(jsonData, &reqMap) if err != nil { return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = json.Marshal(reqMap) @@ -135,7 +75,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -153,17 +93,17 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") } return nil } diff --git a/relay/websocket.go b/relay/websocket.go index 3715b237..22b681f1 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -15,13 +15,6 @@ import ( func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) { relayInfo := relaycommon.GenRelayInfoWs(c, ws) - // get & validate textRequest 获取并验证文本请求 - //realtimeEvent, err := getAndValidateWssRequest(c, ws) - //if err != nil { - // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error())) - // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) - //} - err := helper.ModelMappedHelper(c, relayInfo, nil) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) diff --git a/router/main.go b/router/main.go index 0d2bfdce..7653f3a5 100644 --- a/router/main.go +++ b/router/main.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/logger" "os" "strings" ) @@ -18,7 +19,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" - common.SysLog("FRONTEND_BASE_URL is ignored on master node") + logger.SysLog("FRONTEND_BASE_URL is ignored on master node") } if frontendBaseUrl == "" { SetWebRouter(router, buildFS, indexPage) diff --git a/router/relay-router.go b/router/relay-router.go index cd656580..e0f05e97 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -1,11 +1,13 @@ package router import ( - "github.com/gin-gonic/gin" "one-api/constant" "one-api/controller" "one-api/middleware" "one-api/relay" + "one-api/types" + + "github.com/gin-gonic/gin" ) func SetRelayRouter(router *gin.Engine) { @@ -62,28 +64,83 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.Use(middleware.TokenAuth()) relayV1Router.Use(middleware.ModelRequestRateLimit()) { - // WebSocket 路由 + // WebSocket 路由(统一到 Relay) wsRouter := relayV1Router.Group("") wsRouter.Use(middleware.Distribute()) - wsRouter.GET("/realtime", controller.WssRelay) + wsRouter.GET("/realtime", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIRealtime) + }) } { //http router httpRouter := relayV1Router.Group("") httpRouter.Use(middleware.Distribute()) - httpRouter.POST("/messages", controller.RelayClaude) - httpRouter.POST("/completions", controller.Relay) - httpRouter.POST("/chat/completions", controller.Relay) - httpRouter.POST("/edits", controller.Relay) - httpRouter.POST("/images/generations", controller.Relay) - httpRouter.POST("/images/edits", controller.Relay) + + // claude related routes + httpRouter.POST("/messages", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatClaude) + }) + + // chat related routes + httpRouter.POST("/completions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + httpRouter.POST("/chat/completions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + + // response related routes + httpRouter.POST("/responses", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIResponses) + }) + + // image related routes + httpRouter.POST("/edits", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + httpRouter.POST("/images/generations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + httpRouter.POST("/images/edits", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + + // embedding related routes + httpRouter.POST("/embeddings", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatEmbedding) + }) + + // audio related routes + httpRouter.POST("/audio/transcriptions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + httpRouter.POST("/audio/translations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + httpRouter.POST("/audio/speech", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + + // rerank related routes + httpRouter.POST("/rerank", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatRerank) + }) + + // gemini relay routes + httpRouter.POST("/engines/:model/embeddings", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) + httpRouter.POST("/models/*path", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) + + // other relay routes + httpRouter.POST("/moderations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + + // not implemented httpRouter.POST("/images/variations", controller.RelayNotImplemented) - httpRouter.POST("/embeddings", controller.Relay) - httpRouter.POST("/engines/:model/embeddings", controller.Relay) - httpRouter.POST("/audio/transcriptions", controller.Relay) - httpRouter.POST("/audio/translations", controller.Relay) - httpRouter.POST("/audio/speech", controller.Relay) - httpRouter.POST("/responses", controller.Relay) httpRouter.GET("/files", controller.RelayNotImplemented) httpRouter.POST("/files", controller.RelayNotImplemented) httpRouter.DELETE("/files/:id", controller.RelayNotImplemented) @@ -95,9 +152,6 @@ func SetRelayRouter(router *gin.Engine) { httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) httpRouter.DELETE("/models/:model", controller.RelayNotImplemented) - httpRouter.POST("/moderations", controller.Relay) - httpRouter.POST("/rerank", controller.Relay) - httpRouter.POST("/models/*path", controller.Relay) } relayMjRouter := router.Group("/mj") @@ -121,7 +175,9 @@ func SetRelayRouter(router *gin.Engine) { relayGeminiRouter.Use(middleware.Distribute()) { // Gemini API 路径格式: /v1beta/models/{model_name}:{action} - relayGeminiRouter.POST("/models/*path", controller.Relay) + relayGeminiRouter.POST("/models/*path", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) } } diff --git a/service/cf_worker.go b/service/cf_worker.go index ae6e1ffe..65f7f133 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "net/http" - "one-api/common" + "one-api/logger" "one-api/setting" "strings" ) @@ -44,14 +44,14 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { if setting.EnableWorker() { - common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) + logger.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) req := &WorkerRequest{ URL: originUrl, Key: setting.WorkerValidKey, } return DoWorkerRequest(req) } else { - common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) + logger.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) return http.Get(originUrl) } } diff --git a/service/error.go b/service/error.go index 9672402d..668731b0 100644 --- a/service/error.go +++ b/service/error.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/types" "strconv" "strings" @@ -58,7 +59,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError lowerText := strings.ToLower(text) if !strings.HasPrefix(lowerText, "get file base64 from url") { if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - common.SysLog(fmt.Sprintf("error: %s", text)) + logger.SysLog(fmt.Sprintf("error: %s", text)) text = "请求上游地址失败" } } @@ -85,7 +86,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t if err != nil { return } - common.CloseResponseBodyGracefully(resp) + CloseResponseBodyGracefully(resp) var errResponse dto.GeneralErrorResponse err = common.Unmarshal(responseBody, &errResponse) @@ -138,7 +139,7 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { text := err.Error() lowerText := strings.ToLower(text) if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - common.SysLog(fmt.Sprintf("error: %s", text)) + logger.SysLog(fmt.Sprintf("error: %s", text)) text = "请求上游地址失败" } //避免暴露内部错误 diff --git a/common/http.go b/service/http.go similarity index 86% rename from common/http.go rename to service/http.go index d2e824ef..357a2e78 100644 --- a/common/http.go +++ b/service/http.go @@ -1,10 +1,12 @@ -package common +package service import ( "bytes" "fmt" "io" "net/http" + "one-api/common" + "one-api/logger" "github.com/gin-gonic/gin" ) @@ -15,7 +17,7 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) { } err := httpResponse.Body.Close() if err != nil { - SysError("failed to close response body: " + err.Error()) + common.SysError("failed to close response body: " + err.Error()) } } @@ -52,6 +54,6 @@ func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) { _, err := io.Copy(c.Writer, body) if err != nil { - LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) + logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) } } diff --git a/service/image.go b/service/image.go index 252093f1..957ca041 100644 --- a/service/image.go +++ b/service/image.go @@ -8,8 +8,8 @@ import ( "image" "io" "net/http" - "one-api/common" "one-api/constant" + "one-api/logger" "strings" "golang.org/x/image/webp" @@ -113,7 +113,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) { func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { response, err := DoDownloadRequest(imageUrl) if err != nil { - common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) + logger.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) return image.Config{}, "", err } defer response.Body.Close() @@ -131,7 +131,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { var readData []byte for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { - common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) + logger.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) // 从response.Body读取更多的数据直到达到当前的限制 additionalData := make([]byte, limit-int64(len(readData))) @@ -157,11 +157,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) { config, format, err := image.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) - common.SysLog(err.Error()) + logger.SysLog(err.Error()) config, err = webp.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) - common.SysLog(err.Error()) + logger.SysLog(err.Error()) } format = "webp" } diff --git a/service/midjourney.go b/service/midjourney.go index 1fc19682..1d232739 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relayconstant "one-api/relay/constant" "one-api/setting" "strconv" @@ -212,7 +213,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU defer cancel() resp, err := GetHttpClient().Do(req) if err != nil { - common.SysError("do request failed: " + err.Error()) + logger.SysError("do request failed: " + err.Error()) return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err } statusCode := resp.StatusCode @@ -233,7 +234,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err } - common.CloseResponseBodyGracefully(resp) + CloseResponseBodyGracefully(resp) respStr := string(responseBody) log.Printf("respStr: %s", respStr) if respStr == "" { diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go new file mode 100644 index 00000000..3c4d0e7e --- /dev/null +++ b/service/pre_consume_quota.go @@ -0,0 +1,72 @@ +package service + +import ( + "errors" + "fmt" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "net/http" + "one-api/logger" + "one-api/model" + relaycommon "one-api/relay/common" + "one-api/types" +) + +func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { + if preConsumedQuota != 0 { + gopool.Go(func() { + relayInfoCopy := *relayInfo + + err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) + if err != nil { + logger.SysError("error return pre-consumed quota: " + err.Error()) + } + }) + } +} + +// PreConsumeQuota checks if the user has enough quota to pre-consume. +// It returns the pre-consumed quota if successful, or an error if not. +func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) { + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) + if err != nil { + return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) + } + if userQuota <= 0 { + return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + if userQuota-preConsumedQuota < 0 { + return 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + relayInfo.UserQuota = userQuota + if userQuota > 100*preConsumedQuota { + // 用户额度充足,判断令牌额度是否充足 + if !relayInfo.TokenUnlimited { + // 非无限令牌,判断令牌额度是否充足 + tokenQuota := c.GetInt("token_quota") + if tokenQuota > 100*preConsumedQuota { + // 令牌额度充足,信任令牌 + preConsumedQuota = 0 + logger.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) + } + } else { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + logger.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota))) + } + } + + if preConsumedQuota > 0 { + err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) + if err != nil { + return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) + if err != nil { + return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + } + } + relayInfo.FinalPreConsumedQuota = preConsumedQuota + return preConsumedQuota, nil +} diff --git a/service/quota.go b/service/quota.go index 0f618402..d6f49d64 100644 --- a/service/quota.go +++ b/service/quota.go @@ -8,11 +8,12 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - "one-api/relay/helper" "one-api/setting" "one-api/setting/ratio_setting" + "one-api/types" "strings" "time" @@ -129,23 +130,23 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag quota := calculateAudioQuota(quotaInfo) if userQuota < quota { - return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)) + return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota)) } if !token.UnlimitedQuota && token.RemainQuota < quota { - return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota)) } err = PostConsumeQuota(relayInfo, quota, 0, false) if err != nil { return err } - common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) + logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) return nil } func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { + usage *dto.RealtimeUsage, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.InputTokenDetails.TextTokens @@ -159,10 +160,10 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName)) - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - usePrice := priceData.UsePrice + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + usePrice := relayInfo.PriceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -196,8 +197,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) @@ -208,7 +209,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod logContent += ", " + extraContent } other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: usage.InputTokens, @@ -218,7 +219,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -226,8 +226,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod }) } -func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -235,20 +234,20 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := priceData.CompletionRatio - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - cacheRatio := priceData.CacheRatio + completionRatio := relayInfo.PriceData.CompletionRatio + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + cacheRatio := relayInfo.PriceData.CacheRatio cacheTokens := usage.PromptTokensDetails.CachedTokens - cacheCreationRatio := priceData.CacheCreationRatio + cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens if relayInfo.ChannelType == constant.ChannelTypeOpenRouter { promptTokens -= cacheTokens - if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 { - maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData) + if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 { + maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData) if promptTokens >= maybeCacheCreationTokens { cacheCreationTokens = maybeCacheCreationTokens } @@ -257,7 +256,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } calculateQuota := 0.0 - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { calculateQuota = float64(promptTokens) calculateQuota += float64(cacheTokens) * cacheRatio calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio @@ -282,23 +281,23 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游出错)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, - cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, @@ -308,7 +307,6 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -317,7 +315,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } -func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int { +func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int { if priceData.CacheCreationRatio == 1 { return 0 } @@ -338,8 +336,7 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData (promptCacheCreatePrice - quotaPrice))) } -func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.PromptTokensDetails.TextTokens @@ -353,10 +350,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)) - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - usePrice := priceData.UsePrice + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + usePrice := relayInfo.PriceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -390,18 +387,18 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } @@ -410,7 +407,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, logContent += ", " + extraContent } other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: usage.PromptTokens, @@ -420,7 +417,6 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -443,7 +439,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { return err } if !relayInfo.TokenUnlimited && token.RemainQuota < quota { - return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota)) } err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) if err != nil { @@ -501,7 +497,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon prompt := "您的额度即将用尽" topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" - err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink})) + err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink})) if err != nil { common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error())) } diff --git a/service/token_counter.go b/service/token_counter.go index eed5b5ca..ec817182 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -4,18 +4,22 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tiktoken-go/tokenizer" - "github.com/tiktoken-go/tokenizer/codec" "image" "log" "math" "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" + "one-api/types" "strings" "sync" "unicode/utf8" + + "github.com/gin-gonic/gin" + "github.com/tiktoken-go/tokenizer" + "github.com/tiktoken-go/tokenizer/codec" ) // tokenEncoderMap won't grow after initialization @@ -28,9 +32,9 @@ var tokenEncoderMap = make(map[string]tokenizer.Codec) var tokenEncoderMutex sync.RWMutex func InitTokenEncoders() { - common.SysLog("initializing token encoders") + logger.SysLog("initializing token encoders") defaultTokenEncoder = codec.NewCl100kBase() - common.SysLog("token encoders initialized") + logger.SysLog("token encoders initialized") } func getTokenEncoder(model string) tokenizer.Codec { @@ -72,52 +76,95 @@ func getTokenNum(tokenEncoder tokenizer.Codec, text string) int { return tkm } -func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { - if imageUrl == nil { +func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) { + if fileMeta == nil { return 0, fmt.Errorf("image_url_is_nil") } + + // Defaults for 4o/4.1/4.5 family unless overridden below baseTokens := 85 - if model == "glm-4v" { + tileTokens := 170 + + // Model classification + lowerModel := strings.ToLower(model) + + // Special cases from existing behavior + if strings.HasPrefix(lowerModel, "glm-4") { return 1047, nil } - if imageUrl.Detail == "low" { + + // Patch-based models (32x32 patches, capped at 1536, with multiplier) + isPatchBased := false + multiplier := 1.0 + switch { + case strings.Contains(lowerModel, "gpt-4.1-mini"): + isPatchBased = true + multiplier = 1.62 + case strings.Contains(lowerModel, "gpt-4.1-nano"): + isPatchBased = true + multiplier = 2.46 + case strings.HasPrefix(lowerModel, "o4-mini"): + isPatchBased = true + multiplier = 1.72 + case strings.HasPrefix(lowerModel, "gpt-5-mini"): + isPatchBased = true + multiplier = 1.62 + case strings.HasPrefix(lowerModel, "gpt-5-nano"): + isPatchBased = true + multiplier = 2.46 + } + + // Tile-based model tokens and bases per doc + if !isPatchBased { + if strings.HasPrefix(lowerModel, "gpt-4o-mini") { + baseTokens = 2833 + tileTokens = 5667 + } else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) { + baseTokens = 70 + tileTokens = 140 + } else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") { + baseTokens = 75 + tileTokens = 150 + } else if strings.Contains(lowerModel, "computer-use-preview") { + baseTokens = 65 + tileTokens = 129 + } else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") { + baseTokens = 85 + tileTokens = 170 + } + } + + // Respect existing feature flags/short-circuits + if fileMeta.Detail == "low" && !isPatchBased { return baseTokens, nil } if !constant.GetMediaTokenNotStream && !stream { return 3 * baseTokens, nil } - - // 同步One API的图片计费逻辑 - if imageUrl.Detail == "auto" || imageUrl.Detail == "" { - imageUrl.Detail = "high" + // Normalize detail + if fileMeta.Detail == "auto" || fileMeta.Detail == "" { + fileMeta.Detail = "high" } - - tileTokens := 170 - if strings.HasPrefix(model, "gpt-4o-mini") { - tileTokens = 5667 - baseTokens = 2833 - } - // 是否统计图片token + // Whether to count image tokens at all if !constant.GetMediaToken { return 3 * baseTokens, nil } - if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic { - return 3 * baseTokens, nil - } + + // Decode image to get dimensions var config image.Config var err error var format string var b64str string - if strings.HasPrefix(imageUrl.Url, "http") { - config, format, err = DecodeUrlImageData(imageUrl.Url) + if strings.HasPrefix(fileMeta.Data, "http") { + config, format, err = DecodeUrlImageData(fileMeta.Data) } else { - common.SysLog(fmt.Sprintf("decoding image")) - config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url) + logger.SysLog(fmt.Sprintf("decoding image")) + config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data) } if err != nil { return 0, err } - imageUrl.MimeType = format + fileMeta.MimeType = format if config.Width == 0 || config.Height == 0 { // not an image @@ -125,60 +172,144 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m // file type return 3 * baseTokens, nil } - return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url)) + return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data)) } - shortSide := config.Width - otherSide := config.Height - log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height) - // 缩放倍数 - scale := 1.0 - if config.Height < shortSide { - shortSide = config.Height - otherSide = config.Width + width := config.Width + height := config.Height + log.Printf("format: %s, width: %d, height: %d", format, width, height) + + if isPatchBased { + // 32x32 patch-based calculation with 1536 cap and model multiplier + ceilDiv := func(a, b int) int { return (a + b - 1) / b } + rawPatchesW := ceilDiv(width, 32) + rawPatchesH := ceilDiv(height, 32) + rawPatches := rawPatchesW * rawPatchesH + if rawPatches > 1536 { + // scale down + area := float64(width * height) + r := math.Sqrt(float64(32*32*1536) / area) + wScaled := float64(width) * r + hScaled := float64(height) * r + // adjust to fit whole number of patches after scaling + adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0) + adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0) + adj := math.Min(adjW, adjH) + if !math.IsNaN(adj) && adj > 0 { + r = r * adj + } + wScaled = float64(width) * r + hScaled = float64(height) * r + patchesW := math.Ceil(wScaled / 32.0) + patchesH := math.Ceil(hScaled / 32.0) + imageTokens := int(patchesW * patchesH) + if imageTokens > 1536 { + imageTokens = 1536 + } + return int(math.Round(float64(imageTokens) * multiplier)), nil + } + // below cap + imageTokens := rawPatches + return int(math.Round(float64(imageTokens) * multiplier)), nil } - // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768 - if shortSide > 768 { - scale = float64(shortSide) / 768 - shortSide = 768 + // Tile-based calculation for 4o/4.1/4.5/o1/o3/etc. + // Step 1: fit within 2048x2048 square + maxSide := math.Max(float64(width), float64(height)) + fitScale := 1.0 + if maxSide > 2048 { + fitScale = maxSide / 2048.0 } - // 将另一边按照相同的比例缩小,向上取整 - otherSide = int(math.Ceil(float64(otherSide) / scale)) - log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale) - // 计算图片的token数量(边的长度除以512,向上取整) - tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512) - log.Printf("tiles: %d", tiles) + fitW := int(math.Round(float64(width) / fitScale)) + fitH := int(math.Round(float64(height) / fitScale)) + + // Step 2: scale so that shortest side is exactly 768 + minSide := math.Min(float64(fitW), float64(fitH)) + if minSide == 0 { + return baseTokens, nil + } + shortScale := 768.0 / minSide + finalW := int(math.Round(float64(fitW) * shortScale)) + finalH := int(math.Round(float64(fitH) * shortScale)) + + // Count 512px tiles + tilesW := (finalW + 512 - 1) / 512 + tilesH := (finalH + 512 - 1) / 512 + tiles := tilesW * tilesH + + if common.DebugEnabled { + log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles) + } + return tiles*tileTokens + baseTokens, nil } -func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) { - tkm := 0 - msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream) - if err != nil { - return 0, err +func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { + if meta == nil { + return 0, errors.New("token count meta is nil") } - tkm += msgTokens - if request.Tools != nil { - openaiTools := request.Tools - countStr := "" - for _, tool := range openaiTools { - countStr = tool.Function.Name - if tool.Function.Description != "" { - countStr += tool.Function.Description - } - if tool.Function.Parameters != nil { - countStr += fmt.Sprintf("%v", tool.Function.Parameters) - } - } - toolTokens := CountTokenInput(countStr, request.Model) - tkm += 8 - tkm += toolTokens + model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) + tkm := CountTextToken(meta.CombineText, model) + + if info.RelayFormat == types.RelayFormatOpenAI { + tkm += meta.ToolsCount * 8 + tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量 + tkm += meta.NameCount * 3 + tkm += 3 } + for _, file := range meta.Files { + switch file.FileType { + case types.FileTypeImage: + if info.RelayFormat == types.RelayFormatGemini { + tkm += 240 + } else { + token, err := getImageToken(file, model, info.IsStream) + if err != nil { + return 0, fmt.Errorf("error counting image token: %v", err) + } + tkm += token + } + case types.FileTypeAudio: + tkm += 100 + case types.FileTypeVideo: + tkm += 5000 + case types.FileTypeFile: + tkm += 5000 + } + } + + common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm) return tkm, nil } +//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) { +// tkm := 0 +// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream) +// if err != nil { +// return 0, err +// } +// tkm += msgTokens +// if request.Tools != nil { +// openaiTools := request.Tools +// countStr := "" +// for _, tool := range openaiTools { +// countStr = tool.Function.Name +// if tool.Function.Description != "" { +// countStr += tool.Function.Description +// } +// if tool.Function.Parameters != nil { +// countStr += fmt.Sprintf("%v", tool.Function.Parameters) +// } +// } +// toolTokens := CountTokenInput(countStr, request.Model) +// tkm += 8 +// tkm += toolTokens +// } +// +// return tkm, nil +//} + func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) { tkm := 0 @@ -338,58 +469,55 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, return textToken, audioToken, nil } -func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) { - //recover when panic - tokenEncoder := getTokenEncoder(model) - // Reference: - // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - // https://github.com/pkoukk/tiktoken-go/issues/6 - // - // Every message follows <|start|>{role/name}\n{content}<|end|>\n - var tokensPerMessage int - var tokensPerName int - if model == "gpt-3.5-turbo-0301" { - tokensPerMessage = 4 - tokensPerName = -1 // If there's a name, the role is omitted - } else { - tokensPerMessage = 3 - tokensPerName = 1 - } - tokenNum := 0 - for _, message := range messages { - tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.Role) - if message.Content != nil { - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) - } - arrayContent := message.ParseContent() - for _, m := range arrayContent { - if m.Type == dto.ContentTypeImageURL { - imageUrl := m.GetImageMedia() - imageTokenNum, err := getImageToken(info, imageUrl, model, stream) - if err != nil { - return 0, err - } - tokenNum += imageTokenNum - log.Printf("image token num: %d", imageTokenNum) - } else if m.Type == dto.ContentTypeInputAudio { - // TODO: 音频token数量计算 - tokenNum += 100 - } else if m.Type == dto.ContentTypeFile { - tokenNum += 5000 - } else if m.Type == dto.ContentTypeVideoUrl { - tokenNum += 5000 - } else { - tokenNum += getTokenNum(tokenEncoder, m.Text) - } - } - } - } - tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum, nil -} +//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) { +// //recover when panic +// tokenEncoder := getTokenEncoder(model) +// // Reference: +// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb +// // https://github.com/pkoukk/tiktoken-go/issues/6 +// // +// // Every message follows <|start|>{role/name}\n{content}<|end|>\n +// var tokensPerMessage int +// var tokensPerName int +// +// tokensPerMessage = 3 +// tokensPerName = 1 +// +// tokenNum := 0 +// for _, message := range messages { +// tokenNum += tokensPerMessage +// tokenNum += getTokenNum(tokenEncoder, message.Role) +// if message.Content != nil { +// if message.Name != nil { +// tokenNum += tokensPerName +// tokenNum += getTokenNum(tokenEncoder, *message.Name) +// } +// arrayContent := message.ParseContent() +// for _, m := range arrayContent { +// if m.Type == dto.ContentTypeImageURL { +// imageUrl := m.GetImageMedia() +// imageTokenNum, err := getImageToken(info, imageUrl, model, stream) +// if err != nil { +// return 0, err +// } +// tokenNum += imageTokenNum +// log.Printf("image token num: %d", imageTokenNum) +// } else if m.Type == dto.ContentTypeInputAudio { +// // TODO: 音频token数量计算 +// tokenNum += 100 +// } else if m.Type == dto.ContentTypeFile { +// tokenNum += 5000 +// } else if m.Type == dto.ContentTypeVideoUrl { +// tokenNum += 5000 +// } else { +// tokenNum += getTokenNum(tokenEncoder, m.Text) +// } +// } +// } +// } +// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> +// return tokenNum, nil +//} func CountTokenInput(input any, model string) int { switch v := input.(type) { diff --git a/service/user_notify.go b/service/user_notify.go index 96664007..1fcc62d3 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -4,6 +4,7 @@ import ( "fmt" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "strings" ) @@ -12,7 +13,7 @@ func NotifyRootUser(t string, subject string, content string) { user := model.GetRootUser().ToBaseUser() err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil)) if err != nil { - common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error())) } } @@ -25,7 +26,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // Check notification limit canSend, err := CheckNotificationLimit(userId, data.Type) if err != nil { - common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) return err } if !canSend { @@ -37,14 +38,14 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // check setting email userEmail = userSetting.NotificationEmail if userEmail == "" { - common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) + logger.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) return nil } return sendEmailNotify(userEmail, data) case dto.NotifyTypeWebhook: webhookURLStr := userSetting.WebhookUrl if webhookURLStr == "" { - common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) + logger.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) return nil } diff --git a/setting/chat.go b/setting/chat.go index b97d65ce..b417af28 100644 --- a/setting/chat.go +++ b/setting/chat.go @@ -2,7 +2,7 @@ package setting import ( "encoding/json" - "one-api/common" + "one-api/logger" ) var Chats = []map[string]string{ @@ -37,7 +37,7 @@ func UpdateChatsByJsonString(jsonString string) error { func Chats2JsonString() string { jsonBytes, err := json.Marshal(Chats) if err != nil { - common.SysError("error marshalling chats: " + err.Error()) + logger.SysError("error marshalling chats: " + err.Error()) return "[]" } return string(jsonBytes) diff --git a/setting/config/config.go b/setting/config/config.go index 3af51b14..2e43e0a7 100644 --- a/setting/config/config.go +++ b/setting/config/config.go @@ -2,7 +2,7 @@ package config import ( "encoding/json" - "one-api/common" + "one-api/logger" "reflect" "strconv" "strings" @@ -57,7 +57,7 @@ func (cm *ConfigManager) LoadFromDB(options map[string]string) error { // 如果找到配置项,则更新配置 if len(configMap) > 0 { if err := updateConfigFromMap(config, configMap); err != nil { - common.SysError("failed to update config " + name + ": " + err.Error()) + logger.SysError("failed to update config " + name + ": " + err.Error()) continue } } diff --git a/setting/rate_limit.go b/setting/rate_limit.go index d550b2c3..dcb9fae5 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" "math" - "one-api/common" + "one-api/logger" "sync" ) @@ -21,7 +21,7 @@ func ModelRequestRateLimitGroup2JSONString() string { jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - common.SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index 3f223bc3..47079850 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -2,7 +2,7 @@ package ratio_setting import ( "encoding/json" - "one-api/common" + "one-api/logger" "sync" ) @@ -89,7 +89,7 @@ func CacheRatio2JSONString() string { defer cacheRatioMapMutex.RUnlock() jsonBytes, err := json.Marshal(cacheRatioMap) if err != nil { - common.SysError("error marshalling cache ratio: " + err.Error()) + logger.SysError("error marshalling cache ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/group_ratio.go b/setting/ratio_setting/group_ratio.go index 86f4a8d1..c1a666e9 100644 --- a/setting/ratio_setting/group_ratio.go +++ b/setting/ratio_setting/group_ratio.go @@ -3,7 +3,7 @@ package ratio_setting import ( "encoding/json" "errors" - "one-api/common" + "one-api/logger" "sync" ) @@ -48,7 +48,7 @@ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(groupRatio) if err != nil { - common.SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -67,7 +67,7 @@ func GetGroupRatio(name string) float64 { ratio, ok := groupRatio[name] if !ok { - common.SysError("group ratio not found: " + name) + logger.SysError("group ratio not found: " + name) return 1 } return ratio @@ -94,7 +94,7 @@ func GroupGroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupGroupRatio) if err != nil { - common.SysError("error marshalling group-group ratio: " + err.Error()) + logger.SysError("error marshalling group-group ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 4a19895e..ce822800 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -320,7 +320,7 @@ func ModelPrice2JSONString() string { modelPriceMapMutex.RLock() defer modelPriceMapMutex.RUnlock() - jsonBytes, err := json.Marshal(modelPriceMap) + jsonBytes, err := common.Marshal(modelPriceMap) if err != nil { common.SysError("error marshalling model price: " + err.Error()) } @@ -359,7 +359,7 @@ func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() modelRatioMap = make(map[string]float64) - err := json.Unmarshal([]byte(jsonStr), &modelRatioMap) + err := common.Unmarshal([]byte(jsonStr), &modelRatioMap) if err == nil { InvalidateExposedDataCache() } @@ -388,7 +388,7 @@ func GetModelRatio(name string) (float64, bool, string) { } func DefaultModelRatio2JSONString() string { - jsonBytes, err := json.Marshal(defaultModelRatio) + jsonBytes, err := common.Marshal(defaultModelRatio) if err != nil { common.SysError("error marshalling model ratio: " + err.Error()) } @@ -420,7 +420,7 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { CompletionRatioMutex.Lock() defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) - err := json.Unmarshal([]byte(jsonStr), &CompletionRatio) + err := common.Unmarshal([]byte(jsonStr), &CompletionRatio) if err == nil { InvalidateExposedDataCache() } @@ -594,7 +594,7 @@ func ModelRatio2JSONString() string { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() - jsonBytes, err := json.Marshal(modelRatioMap) + jsonBytes, err := common.Marshal(modelRatioMap) if err != nil { common.SysError("error marshalling model ratio: " + err.Error()) } @@ -610,7 +610,7 @@ var imageRatioMapMutex sync.RWMutex func ImageRatio2JSONString() string { imageRatioMapMutex.RLock() defer imageRatioMapMutex.RUnlock() - jsonBytes, err := json.Marshal(imageRatioMap) + jsonBytes, err := common.Marshal(imageRatioMap) if err != nil { common.SysError("error marshalling cache ratio: " + err.Error()) } @@ -621,7 +621,7 @@ func UpdateImageRatioByJSONString(jsonStr string) error { imageRatioMapMutex.Lock() defer imageRatioMapMutex.Unlock() imageRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &imageRatioMap) + return common.Unmarshal([]byte(jsonStr), &imageRatioMap) } func GetImageRatio(name string) (float64, bool) { diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go index 0ae132d0..bcbe712c 100644 --- a/setting/user_usable_group.go +++ b/setting/user_usable_group.go @@ -2,7 +2,7 @@ package setting import ( "encoding/json" - "one-api/common" + "one-api/logger" "sync" ) @@ -29,7 +29,7 @@ func UserUsableGroups2JSONString() string { jsonBytes, err := json.Marshal(userUsableGroups) if err != nil { - common.SysError("error marshalling user groups: " + err.Error()) + logger.SysError("error marshalling user groups: " + err.Error()) } return string(jsonBytes) } diff --git a/types/error.go b/types/error.go index 5a143612..2cfeb541 100644 --- a/types/error.go +++ b/types/error.go @@ -39,12 +39,13 @@ const ( ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected" // new api error - ErrorCodeCountTokenFailed ErrorCode = "count_token_failed" - ErrorCodeModelPriceError ErrorCode = "model_price_error" - ErrorCodeInvalidApiType ErrorCode = "invalid_api_type" - ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed" - ErrorCodeDoRequestFailed ErrorCode = "do_request_failed" - ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed" + ErrorCodeCountTokenFailed ErrorCode = "count_token_failed" + ErrorCodeModelPriceError ErrorCode = "model_price_error" + ErrorCodeInvalidApiType ErrorCode = "invalid_api_type" + ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed" + ErrorCodeDoRequestFailed ErrorCode = "do_request_failed" + ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed" + ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed" // channel error ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" diff --git a/types/price_data.go b/types/price_data.go new file mode 100644 index 00000000..f6a92d7e --- /dev/null +++ b/types/price_data.go @@ -0,0 +1,31 @@ +package types + +import "fmt" + +type GroupRatioInfo struct { + GroupRatio float64 + GroupSpecialRatio float64 + HasSpecialRatio bool +} + +type PriceData struct { + ModelPrice float64 + ModelRatio float64 + CompletionRatio float64 + CacheRatio float64 + CacheCreationRatio float64 + ImageRatio float64 + UsePrice bool + ShouldPreConsumedQuota int + GroupRatioInfo GroupRatioInfo +} + +type PerCallPriceData struct { + ModelPrice float64 + Quota int + GroupRatioInfo GroupRatioInfo +} + +func (p PriceData) ToSetting() string { + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) +} diff --git a/types/relay_format.go b/types/relay_format.go new file mode 100644 index 00000000..4c29d649 --- /dev/null +++ b/types/relay_format.go @@ -0,0 +1,15 @@ +package types + +type RelayFormat string + +const ( + RelayFormatOpenAI RelayFormat = "openai" + RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" + RelayFormatOpenAIResponses = "openai_responses" + RelayFormatOpenAIAudio = "openai_audio" + RelayFormatOpenAIImage = "openai_image" + RelayFormatOpenAIRealtime = "openai_realtime" + RelayFormatRerank = "rerank" + RelayFormatEmbedding = "embedding" +) diff --git a/types/relay_request.go b/types/relay_request.go new file mode 100644 index 00000000..b9d092f0 --- /dev/null +++ b/types/relay_request.go @@ -0,0 +1,27 @@ +package types + +type RelayRequest struct { + OriginRequest any + Format RelayFormat + PromptTokenCount int +} + +func (r *RelayRequest) CopyOriginRequest() any { + if r.OriginRequest == nil { + return nil + } + switch v := r.OriginRequest.(type) { + case *GeneralOpenAIRequest: + return v.Copy() + case *GeneralClaudeRequest: + return v.Copy() + case *GeneralGeminiRequest: + return v.Copy() + case *GeneralRerankRequest: + return v.Copy() + case *GeneralEmbeddingRequest: + return v.Copy() + default: + return nil + } +} diff --git a/types/request_meta.go b/types/request_meta.go new file mode 100644 index 00000000..427bacb9 --- /dev/null +++ b/types/request_meta.go @@ -0,0 +1,45 @@ +package types + +type FileType string + +const ( + FileTypeImage FileType = "image" // Image file type + FileTypeAudio FileType = "audio" // Audio file type + FileTypeVideo FileType = "video" // Video file type + FileTypeFile FileType = "file" // Generic file type +) + +type TokenType string + +const ( + TokenTypeTextNumber TokenType = "text_number" // Text or number tokens + TokenTypeTokenizer TokenType = "tokenizer" // Tokenizer tokens + TokenTypeImage TokenType = "image" // Image tokens +) + +type TokenCountMeta struct { + TokenType TokenType `json:"token_type,omitempty"` // Type of tokens used in the request + CombineText string `json:"combine_text,omitempty"` // Combined text from all messages + ToolsCount int `json:"tools_count,omitempty"` // Number of tools used + NameCount int `json:"name_count,omitempty"` // Number of names in the request + MessagesCount int `json:"messages_count,omitempty"` // Number of messages in the request + Files []*FileMeta `json:"files,omitempty"` // List of files, each with type and content + MaxTokens int `json:"max_tokens,omitempty"` // Maximum tokens allowed in the request + + ImagePriceRatio float64 `json:"image_ratio,omitempty"` // Ratio for image size, if applicable + //IsStreaming bool `json:"is_streaming,omitempty"` // Indicates if the request is streaming +} + +type FileMeta struct { + FileType + MimeType string + Data string + Detail string +} + +type RequestMeta struct { + OriginalModelName string `json:"original_model_name"` + UserUsingGroup string `json:"user_using_group"` + PromptTokens int `json:"prompt_tokens"` + PreConsumedQuota int `json:"pre_consumed_quota"` +} From baf086d5b335ffec3524ce390b4fdc31f9b14376 Mon Sep 17 00:00:00 2001 From: fatcat-ww Date: Thu, 14 Aug 2025 20:38:50 +0800 Subject: [PATCH 02/12] Add files via upload --- web/src/components/layout/HeaderBar.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/web/src/components/layout/HeaderBar.js b/web/src/components/layout/HeaderBar.js index 4d83d48b..0570a0fb 100644 --- a/web/src/components/layout/HeaderBar.js +++ b/web/src/components/layout/HeaderBar.js @@ -628,7 +628,8 @@ const HeaderBar = ({ onMobileMenuToggle, drawerOpen }) => {
{ ); }; -export default HeaderBar; +export default HeaderBar; \ No newline at end of file From 6748b006b785f2baf63391bd35a32a50add115f3 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 14 Aug 2025 21:10:04 +0800 Subject: [PATCH 03/12] refactor: centralize logging and update resource initialization This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability. - Minor adjustments to improve code clarity and organization throughout various modules. This change aims to streamline logging and improve the overall architecture of the codebase. --- common/limiter/limiter.go | 4 +- controller/channel-billing.go | 5 +- controller/channel-test.go | 27 ++++-- controller/console_migrate.go | 7 +- controller/github.go | 5 +- controller/model.go | 4 +- controller/oidc.go | 11 ++- controller/playground.go | 2 +- controller/relay.go | 62 ++++++------- controller/task.go | 18 ++-- controller/task_video.go | 7 +- controller/token.go | 3 +- controller/twofa.go | 13 ++- controller/user.go | 6 +- dto/openai_request.go | 6 +- dto/request_common.go | 13 +++ main.go | 34 +++---- middleware/recover.go | 6 +- middleware/turnstile-check.go | 5 +- model/ability.go | 9 +- model/channel.go | 27 +++--- model/channel_cache.go | 5 +- model/log.go | 2 +- model/main.go | 31 ++++--- model/option.go | 5 +- model/pricing.go | 3 +- model/token.go | 19 ++-- model/twofa.go | 9 +- model/usedata.go | 9 +- model/user.go | 24 ++--- model/user_cache.go | 5 +- model/utils.go | 9 +- relay/channel/ali/adaptor.go | 16 ++-- relay/channel/ali/image.go | 9 +- relay/channel/ali/text.go | 5 +- relay/channel/baidu/adaptor.go | 2 +- relay/channel/baidu/relay-baidu.go | 5 +- relay/channel/baidu_v2/adaptor.go | 10 +-- relay/channel/claude/adaptor.go | 4 +- relay/channel/claude/relay-claude.go | 21 +++-- relay/channel/cloudflare/adaptor.go | 8 +- relay/channel/cohere/adaptor.go | 4 +- relay/channel/cohere/relay-cohere.go | 5 +- relay/channel/coze/adaptor.go | 2 +- relay/channel/coze/relay-coze.go | 15 ++-- relay/channel/deepseek/adaptor.go | 6 +- relay/channel/dify/adaptor.go | 6 +- relay/channel/dify/relay-dify.go | 25 +++--- relay/channel/gemini/adaptor.go | 6 +- relay/channel/gemini/relay-gemini.go | 8 +- relay/channel/jimeng/adaptor.go | 2 +- relay/channel/jina/adaptor.go | 4 +- relay/channel/minimax/relay-minimax.go | 2 +- relay/channel/mistral/adaptor.go | 2 +- relay/channel/mokaai/adaptor.go | 2 +- relay/channel/moonshot/adaptor.go | 18 ++-- relay/channel/ollama/adaptor.go | 8 +- relay/channel/openai/adaptor.go | 28 +++--- relay/channel/openai/helper.go | 25 +++--- relay/channel/openai/relay-openai.go | 10 +-- relay/channel/palm/adaptor.go | 2 +- relay/channel/palm/relay-palm.go | 7 +- relay/channel/perplexity/adaptor.go | 2 +- relay/channel/siliconflow/adaptor.go | 10 +-- relay/channel/task/jimeng/adaptor.go | 2 +- relay/channel/task/kling/adaptor.go | 2 +- relay/channel/task/suno/adaptor.go | 5 +- relay/channel/task/vidu/adaptor.go | 2 +- relay/channel/tencent/adaptor.go | 2 +- relay/channel/tencent/relay-tencent.go | 7 +- relay/channel/volcengine/adaptor.go | 12 +-- relay/channel/xai/adaptor.go | 4 +- relay/channel/xai/text.go | 5 +- relay/channel/xunfei/relay-xunfei.go | 11 ++- relay/channel/zhipu/adaptor.go | 2 +- relay/channel/zhipu/relay-zhipu.go | 9 +- relay/channel/zhipu_4v/adaptor.go | 2 +- relay/common/relay_info.go | 36 ++++++-- relay/helper/price.go | 46 +++++----- relay/helper/valid_request.go | 2 +- relay/{chat_handler.go => mjproxy_handler.go} | 90 ++++++++----------- relay/relay-text.go | 20 +++++ relay/relay_task.go | 12 +-- relay/websocket.go | 44 +++------ router/main.go | 6 +- service/cf_worker.go | 6 +- service/error.go | 5 +- service/image.go | 10 +-- service/log_info_generate.go | 4 +- service/midjourney.go | 3 +- service/pre_consume_quota.go | 3 +- service/token_counter.go | 12 ++- service/user_notify.go | 9 +- setting/chat.go | 4 +- setting/config/config.go | 4 +- setting/rate_limit.go | 4 +- setting/ratio_setting/cache_ratio.go | 4 +- setting/ratio_setting/group_ratio.go | 8 +- setting/user_usable_group.go | 4 +- types/relay_format.go | 3 + types/relay_request.go | 27 ------ 101 files changed, 537 insertions(+), 568 deletions(-) rename relay/{chat_handler.go => mjproxy_handler.go} (87%) delete mode 100644 types/relay_request.go diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index fcfcb0c3..ef5d1935 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -5,7 +5,7 @@ import ( _ "embed" "fmt" "github.com/go-redis/redis/v8" - "one-api/logger" + "one-api/common" "sync" ) @@ -27,7 +27,7 @@ func New(ctx context.Context, r *redis.Client) *RedisLimiter { // 预加载脚本 limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result() if err != nil { - logger.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) + common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) } instance = &RedisLimiter{ client: r, diff --git a/controller/channel-billing.go b/controller/channel-billing.go index bbf0f97a..5152e060 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -8,7 +8,6 @@ import ( "net/http" "one-api/common" "one-api/constant" - "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -486,8 +485,8 @@ func UpdateAllChannelsBalance(c *gin.Context) { func AutomaticallyUpdateChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - logger.SysLog("updating all channels") + common.SysLog("updating all channels") _ = updateAllChannelsBalance() - logger.SysLog("channels update done") + common.SysLog("channels update done") } } diff --git a/controller/channel-test.go b/controller/channel-test.go index ec2e6226..32486a8b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -13,7 +13,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/middleware" "one-api/model" "one-api/relay" @@ -133,8 +132,17 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: newAPIError, } } + request := buildTestRequest(testModel) - info := relaycommon.GenRelayInfo(c) + info, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, request, nil) + + if err != nil { + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed), + } + } err = helper.ModelMappedHelper(c, info, nil) if err != nil { @@ -144,7 +152,9 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError), } } + testModel = info.UpstreamModelName + request.Model = testModel apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) @@ -156,13 +166,12 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } - request := buildTestRequest(testModel) // 创建一个用于日志的 info 副本,移除 ApiKey logInfo := *info logInfo.ApiKey = "" - logger.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) + common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) - priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens())) + priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta()) if err != nil { return testResult{ context: c, @@ -280,7 +289,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { Group: info.UsingGroup, Other: other, }) - logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) + common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return testResult{ context: c, localErr: nil, @@ -462,13 +471,13 @@ func TestAllChannels(c *gin.Context) { func AutomaticallyTestChannels(frequency int) { if frequency <= 0 { - logger.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") + common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") return } for { time.Sleep(time.Duration(frequency) * time.Minute) - logger.SysLog("testing all channels") + common.SysLog("testing all channels") _ = testAllChannels(false) - logger.SysLog("channel test finished") + common.SysLog("channel test finished") } } diff --git a/controller/console_migrate.go b/controller/console_migrate.go index d21f5e21..f0812c3d 100644 --- a/controller/console_migrate.go +++ b/controller/console_migrate.go @@ -4,10 +4,11 @@ package controller import ( "encoding/json" - "github.com/gin-gonic/gin" "net/http" - "one-api/logger" + "one-api/common" "one-api/model" + + "github.com/gin-gonic/gin" ) // MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.* @@ -98,6 +99,6 @@ func MigrateConsoleSetting(c *gin.Context) { // 重新加载 OptionMap model.InitOptionMap() - logger.SysLog("console setting migrated") + common.SysLog("console setting migrated") c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) } diff --git a/controller/github.go b/controller/github.go index 0715a8fe..881d6dc1 100644 --- a/controller/github.go +++ b/controller/github.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "one-api/common" - "one-api/logger" "one-api/model" "strconv" "time" @@ -48,7 +47,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } res, err := client.Do(req) if err != nil { - logger.SysLog(err.Error()) + common.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res.Body.Close() @@ -64,7 +63,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) res2, err := client.Do(req) if err != nil { - logger.SysLog(err.Error()) + common.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res2.Body.Close() diff --git a/controller/model.go b/controller/model.go index d03fdeb2..398503e8 100644 --- a/controller/model.go +++ b/controller/model.go @@ -93,7 +93,9 @@ func init() { if !success || apiType == constant.APITypeAIProxyLibrary { continue } - meta := &relaycommon.RelayInfo{ChannelType: i} + meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{ + ChannelType: i, + }} adaptor := relay.GetAdaptor(apiType) adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() diff --git a/controller/oidc.go b/controller/oidc.go index 1e3435a8..f3def0e3 100644 --- a/controller/oidc.go +++ b/controller/oidc.go @@ -7,7 +7,6 @@ import ( "net/http" "net/url" "one-api/common" - "one-api/logger" "one-api/model" "one-api/setting" "one-api/setting/system_setting" @@ -59,7 +58,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } res, err := client.Do(req) if err != nil { - logger.SysLog(err.Error()) + common.SysLog(err.Error()) return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") } defer res.Body.Close() @@ -70,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } if oidcResponse.AccessToken == "" { - logger.SysError("OIDC 获取 Token 失败,请检查设置!") + common.SysLog("OIDC 获取 Token 失败,请检查设置!") return nil, errors.New("OIDC 获取 Token 失败,请检查设置!") } @@ -81,12 +80,12 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) res2, err := client.Do(req) if err != nil { - logger.SysLog(err.Error()) + common.SysLog(err.Error()) return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") } defer res2.Body.Close() if res2.StatusCode != http.StatusOK { - logger.SysError("OIDC 获取用户信息失败!请检查设置!") + common.SysLog("OIDC 获取用户信息失败!请检查设置!") return nil, errors.New("OIDC 获取用户信息失败!请检查设置!") } @@ -96,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { return nil, err } if oidcUser.OpenID == "" || oidcUser.Email == "" { - logger.SysError("OIDC 获取用户信息为空!请检查设置!") + common.SysLog("OIDC 获取用户信息为空!请检查设置!") return nil, errors.New("OIDC 获取用户信息为空!请检查设置!") } return &oidcUser, nil diff --git a/controller/playground.go b/controller/playground.go index dd930802..8a1cb2b6 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -56,5 +56,5 @@ func Playground(c *gin.Context) { //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) - Relay(c) + Relay(c, types.RelayFormatOpenAI) } diff --git a/controller/relay.go b/controller/relay.go index 583ac036..8b67fd89 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -104,26 +104,6 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { return } - //includeUsage := true - //// 判断用户是否需要返回使用情况 - //if textRequest.StreamOptions != nil { - // includeUsage = textRequest.StreamOptions.IncludeUsage - //} - // - //// 如果不支持StreamOptions,将StreamOptions设置为nil - //if !relayInfo.SupportStreamOptions || !textRequest.Stream { - // textRequest.StreamOptions = nil - //} else { - // // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions - // if constant.ForceStreamOption { - // textRequest.StreamOptions = &dto.StreamOptions{ - // IncludeUsage: true, - // } - // } - //} - // - //relayInfo.ShouldIncludeUsage = includeUsage - relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws) if err != nil { newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed) @@ -178,7 +158,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { switch relayFormat { case types.RelayFormatOpenAIRealtime: - newAPIError = relay.WssHelper(c, ws) + newAPIError = relay.WssHelper(c, relayInfo) case types.RelayFormatClaude: newAPIError = relay.ClaudeHelper(c, relayInfo) case types.RelayFormatGemini: @@ -324,35 +304,45 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t } func RelayMidjourney(c *gin.Context) { - relayMode := c.GetInt("relay_mode") - var err *dto.MidjourneyResponse - switch relayMode { + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil) + + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "description": fmt.Sprintf("failed to generate relay info: %s", err.Error()), + "type": "upstream_error", + "code": 4, + }) + return + } + + var mjErr *dto.MidjourneyResponse + switch relayInfo.RelayMode { case relayconstant.RelayModeMidjourneyNotify: - err = relay.RelayMidjourneyNotify(c) + mjErr = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: - err = relay.RelayMidjourneyTask(c, relayMode) + mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode) case relayconstant.RelayModeMidjourneyTaskImageSeed: - err = relay.RelayMidjourneyTaskImageSeed(c) + mjErr = relay.RelayMidjourneyTaskImageSeed(c) case relayconstant.RelayModeSwapFace: - err = relay.RelaySwapFace(c) + mjErr = relay.RelaySwapFace(c, relayInfo) default: - err = relay.RelayMidjourneySubmit(c, relayMode) + mjErr = relay.RelayMidjourneySubmit(c, relayInfo) } //err = relayMidjourneySubmit(c, relayMode) - log.Println(err) - if err != nil { + log.Println(mjErr) + if mjErr != nil { statusCode := http.StatusBadRequest - if err.Code == 30 { - err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + if mjErr.Code == 30 { + mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" statusCode = http.StatusTooManyRequests } c.JSON(statusCode, gin.H{ - "description": fmt.Sprintf("%s %s", err.Description, err.Result), + "description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result), "type": "upstream_error", - "code": err.Code, + "code": mjErr.Code, }) channelId := c.GetInt("channel_id") - logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result))) } } diff --git a/controller/task.go b/controller/task.go index a5b28ae2..1082d7a1 100644 --- a/controller/task.go +++ b/controller/task.go @@ -26,7 +26,7 @@ func UpdateTaskBulk() { //imageModel := "midjourney" for { time.Sleep(time.Duration(15) * time.Second) - logger.SysLog("任务进度轮询开始") + common.SysLog("任务进度轮询开始") ctx := context.TODO() allTasks := model.GetAllUnFinishSyncTasks(500) platformTask := make(map[constant.TaskPlatform][]*model.Task) @@ -66,7 +66,7 @@ func UpdateTaskBulk() { UpdateTaskByPlatform(platform, taskChannelM, taskM) } - logger.SysLog("任务进度轮询完成") + common.SysLog("任务进度轮询完成") } } @@ -78,7 +78,7 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) default: if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil { - logger.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) + common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) } } } @@ -100,14 +100,14 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas } channel, err := model.CacheGetChannel(channelId) if err != nil { - logger.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) err = model.TaskBulkUpdate(taskIds, map[string]any{ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), "status": "FAILURE", "progress": "100%", }) if err != nil { - logger.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) + common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) } return err } @@ -119,7 +119,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas "ids": taskIds, }) if err != nil { - logger.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) return err } if resp.StatusCode != http.StatusOK { @@ -129,7 +129,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas defer resp.Body.Close() responseBody, err := io.ReadAll(resp.Body) if err != nil { - logger.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) return err } var responseItems dto.TaskResponse[[]dto.SunoDataResponse] @@ -139,7 +139,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas return err } if !responseItems.IsSuccess() { - logger.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) + common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) return err } @@ -179,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas err = task.Update() if err != nil { - logger.SysError("UpdateMidjourneyTask task error: " + err.Error()) + common.SysLog("UpdateMidjourneyTask task error: " + err.Error()) } } return nil diff --git a/controller/task_video.go b/controller/task_video.go index dca42955..ffb6728b 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "one-api/common" "one-api/constant" "one-api/dto" "one-api/logger" @@ -37,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha "progress": "100%", }) if errUpdate != nil { - logger.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) } return fmt.Errorf("CacheGetChannel failed: %w", err) } @@ -112,7 +113,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.StartTime = now } case model.TaskStatusSuccess: - task.Progress = "100%" + task.Progress = "100%" if task.FinishTime == 0 { task.FinishTime = now } @@ -140,7 +141,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.Progress = taskResult.Progress } if err := task.Update(); err != nil { - logger.SysError("UpdateVideoTask task error: " + err.Error()) + common.SysLog("UpdateVideoTask task error: " + err.Error()) } return nil diff --git a/controller/token.go b/controller/token.go index db575fec..399ccb4f 100644 --- a/controller/token.go +++ b/controller/token.go @@ -3,7 +3,6 @@ package controller import ( "net/http" "one-api/common" - "one-api/logger" "one-api/model" "strconv" @@ -103,7 +102,7 @@ func AddToken(c *gin.Context) { "success": false, "message": "生成令牌失败", }) - logger.SysError("failed to generate token key: " + err.Error()) + common.SysLog("failed to generate token key: " + err.Error()) return } cleanToken := model.Token{ diff --git a/controller/twofa.go b/controller/twofa.go index 0ab66029..1859a128 100644 --- a/controller/twofa.go +++ b/controller/twofa.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "one-api/common" - "one-api/logger" "one-api/model" "strconv" @@ -71,7 +70,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成2FA密钥失败", }) - logger.SysError("生成TOTP密钥失败: " + err.Error()) + common.SysLog("生成TOTP密钥失败: " + err.Error()) return } @@ -82,7 +81,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - logger.SysError("生成备用码失败: " + err.Error()) + common.SysLog("生成备用码失败: " + err.Error()) return } @@ -116,7 +115,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - logger.SysError("保存备用码失败: " + err.Error()) + common.SysLog("保存备用码失败: " + err.Error()) return } @@ -295,7 +294,7 @@ func Get2FAStatus(c *gin.Context) { // 获取剩余备用码数量 backupCount, err := model.GetUnusedBackupCodeCount(userId) if err != nil { - logger.SysError("获取备用码数量失败: " + err.Error()) + common.SysLog("获取备用码数量失败: " + err.Error()) } else { status["backup_codes_remaining"] = backupCount } @@ -369,7 +368,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - logger.SysError("生成备用码失败: " + err.Error()) + common.SysLog("生成备用码失败: " + err.Error()) return } @@ -379,7 +378,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - logger.SysError("保存备用码失败: " + err.Error()) + common.SysLog("保存备用码失败: " + err.Error()) return } diff --git a/controller/user.go b/controller/user.go index 8ce44fa6..a7d59f17 100644 --- a/controller/user.go +++ b/controller/user.go @@ -193,7 +193,7 @@ func Register(c *gin.Context) { "success": false, "message": "数据库错误,请稍后重试", }) - logger.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) + common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) return } if exist { @@ -236,7 +236,7 @@ func Register(c *gin.Context) { "success": false, "message": "生成默认令牌失败", }) - logger.SysError("failed to generate token key: " + err.Error()) + common.SysLog("failed to generate token key: " + err.Error()) return } // 生成默认令牌 @@ -343,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) { "success": false, "message": "生成失败", }) - logger.SysError("failed to generate key: " + err.Error()) + common.SysLog("failed to generate key: " + err.Error()) return } user.SetAccessToken(key) diff --git a/dto/openai_request.go b/dto/openai_request.go index 0c01c503..12aa54f4 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -332,9 +332,9 @@ func (m *MediaContent) GetVideoUrl() *MessageVideoUrl { } type MessageImageUrl struct { - Url string `json:"url"` - Detail string `json:"detail"` - //MimeType string + Url string `json:"url"` + Detail string `json:"detail"` + MimeType string } func (m *MessageImageUrl) IsRemoteImage() bool { diff --git a/dto/request_common.go b/dto/request_common.go index e5dde8b5..8bd25785 100644 --- a/dto/request_common.go +++ b/dto/request_common.go @@ -9,3 +9,16 @@ type Request interface { GetTokenCountMeta() *types.TokenCountMeta IsStream(c *gin.Context) bool } + +type BaseRequest struct { +} + +func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta { + return &types.TokenCountMeta{ + TokenType: types.TokenTypeTokenizer, + } +} + +func (b *BaseRequest) IsStream(c *gin.Context) bool { + return false +} diff --git a/main.go b/main.go index 9a5bd652..2dfddacc 100644 --- a/main.go +++ b/main.go @@ -36,22 +36,22 @@ func main() { err := InitResources() if err != nil { - logger.FatalLog("failed to initialize resources: " + err.Error()) + common.FatalLog("failed to initialize resources: " + err.Error()) return } - logger.SysLog("New API " + common.Version + " started") + common.SysLog("New API " + common.Version + " started") if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) } if common.DebugEnabled { - logger.SysLog("running in debug mode") + common.SysLog("running in debug mode") } defer func() { err := model.CloseDB() if err != nil { - logger.FatalLog("failed to close database: " + err.Error()) + common.FatalLog("failed to close database: " + err.Error()) } }() @@ -60,18 +60,18 @@ func main() { common.MemoryCacheEnabled = true } if common.MemoryCacheEnabled { - logger.SysLog("memory cache enabled") - logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + common.SysLog("memory cache enabled") + common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) // Add panic recovery and retry for InitChannelCache func() { defer func() { if r := recover(); r != nil { - logger.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) + common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) // Retry once _, _, fixErr := model.FixAbility() if fixErr != nil { - logger.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) + common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) } } }() @@ -90,14 +90,14 @@ func main() { if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) if err != nil { - logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) + common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) } go controller.AutomaticallyUpdateChannels(frequency) } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { - logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) } go controller.AutomaticallyTestChannels(frequency) } @@ -111,7 +111,7 @@ func main() { } if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { common.BatchUpdateEnabled = true - logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") model.InitBatchUpdater() } @@ -120,13 +120,13 @@ func main() { log.Println(http.ListenAndServe("0.0.0.0:8005", nil)) }) go common.Monitor() - logger.SysLog("pprof enabled") + common.SysLog("pprof enabled") } // Initialize HTTP server server := gin.New() server.Use(gin.CustomRecovery(func(c *gin.Context, err any) { - logger.SysError(fmt.Sprintf("panic detected: %v", err)) + common.SysLog(fmt.Sprintf("panic detected: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), @@ -156,7 +156,7 @@ func main() { } err = server.Run(":" + port) if err != nil { - logger.FatalLog("failed to start HTTP server: " + err.Error()) + common.FatalLog("failed to start HTTP server: " + err.Error()) } } @@ -165,8 +165,8 @@ func InitResources() error { // This is a placeholder function for future resource initialization err := godotenv.Load(".env") if err != nil { - logger.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") - logger.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") + common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") + common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") } // 加载环境变量 @@ -184,7 +184,7 @@ func InitResources() error { // Initialize SQL Database err = model.InitDB() if err != nil { - logger.FatalLog("failed to initialize database: " + err.Error()) + common.FatalLog("failed to initialize database: " + err.Error()) return err } diff --git a/middleware/recover.go b/middleware/recover.go index 6c9c7ef6..d78c8137 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" - "one-api/logger" + "one-api/common" "runtime/debug" ) @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - logger.SysError(fmt.Sprintf("panic detected: %v", err)) - logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + common.SysLog(fmt.Sprintf("panic detected: %v", err)) + common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index a136a900..106a7278 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -7,7 +7,6 @@ import ( "net/http" "net/url" "one-api/common" - "one-api/logger" ) type turnstileCheckResponse struct { @@ -38,7 +37,7 @@ func TurnstileCheck() gin.HandlerFunc { "remoteip": {c.ClientIP()}, }) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -50,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/model/ability.go b/model/ability.go index ac5530d8..123fc7be 100644 --- a/model/ability.go +++ b/model/ability.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "one-api/common" - "one-api/logger" "strings" "sync" @@ -295,13 +294,13 @@ func FixAbility() (int, int, error) { if common.UsingSQLite { err := DB.Exec("DELETE FROM abilities").Error if err != nil { - logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error())) return 0, 0, err } } else { err := DB.Exec("TRUNCATE TABLE abilities").Error if err != nil { - logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) return 0, 0, err } } @@ -321,7 +320,7 @@ func FixAbility() (int, int, error) { // Delete all abilities of this channel err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error if err != nil { - logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error())) failCount += len(chunk) continue } @@ -329,7 +328,7 @@ func FixAbility() (int, int, error) { for _, channel := range chunk { err = channel.AddAbilities(nil) if err != nil { - logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) + common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) failCount++ } else { successCount++ diff --git a/model/channel.go b/model/channel.go index c0d253fc..af769f63 100644 --- a/model/channel.go +++ b/model/channel.go @@ -9,7 +9,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/types" "strings" "sync" @@ -210,7 +209,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { if channel.OtherInfo != "" { err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { - logger.SysError("failed to unmarshal other info: " + err.Error()) + common.SysLog("failed to unmarshal other info: " + err.Error()) } } return otherInfo @@ -219,7 +218,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { otherInfoBytes, err := json.Marshal(otherInfo) if err != nil { - logger.SysError("failed to marshal other info: " + err.Error()) + common.SysLog("failed to marshal other info: " + err.Error()) return } channel.OtherInfo = string(otherInfoBytes) @@ -489,7 +488,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { ResponseTime: int(responseTime), }).Error if err != nil { - logger.SysError("failed to update response time: " + err.Error()) + common.SysLog("failed to update response time: " + err.Error()) } } @@ -499,7 +498,7 @@ func (channel *Channel) UpdateBalance(balance float64) { Balance: balance, }).Error if err != nil { - logger.SysError("failed to update balance: " + err.Error()) + common.SysLog("failed to update balance: " + err.Error()) } } @@ -615,7 +614,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if shouldUpdateAbilities { err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) if err != nil { - logger.SysError("failed to update ability status: " + err.Error()) + common.SysLog("failed to update ability status: " + err.Error()) } } }() @@ -643,7 +642,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri } err = channel.Save() if err != nil { - logger.SysError("failed to update channel status: " + err.Error()) + common.SysLog("failed to update channel status: " + err.Error()) return false } } @@ -705,7 +704,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models * for _, channel := range channels { err = channel.UpdateAbilities(nil) if err != nil { - logger.SysError("failed to update abilities: " + err.Error()) + common.SysLog("failed to update abilities: " + err.Error()) } } } @@ -729,7 +728,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - logger.SysError("failed to update channel used quota: " + err.Error()) + common.SysLog("failed to update channel used quota: " + err.Error()) } } @@ -822,7 +821,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { if channel.Setting != nil && *channel.Setting != "" { err := common.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { - logger.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) channel.Setting = nil // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -833,7 +832,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { func (channel *Channel) SetSetting(setting dto.ChannelSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - logger.SysError("failed to marshal setting: " + err.Error()) + common.SysLog("failed to marshal setting: " + err.Error()) return } channel.Setting = common.GetPointer[string](string(settingBytes)) @@ -844,7 +843,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { if channel.OtherSettings != "" { err := common.UnmarshalJsonStr(channel.OtherSettings, &setting) if err != nil { - logger.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) channel.OtherSettings = "{}" // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -855,7 +854,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - logger.SysError("failed to marshal setting: " + err.Error()) + common.SysLog("failed to marshal setting: " + err.Error()) return } channel.OtherSettings = string(settingBytes) @@ -866,7 +865,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} { if channel.ParamOverride != nil && *channel.ParamOverride != "" { err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) if err != nil { - logger.SysError("failed to unmarshal param override: " + err.Error()) + common.SysLog("failed to unmarshal param override: " + err.Error()) } } return paramOverride diff --git a/model/channel_cache.go b/model/channel_cache.go index 22216027..86866e40 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -6,7 +6,6 @@ import ( "math/rand" "one-api/common" "one-api/constant" - "one-api/logger" "one-api/setting" "one-api/setting/ratio_setting" "sort" @@ -85,13 +84,13 @@ func InitChannelCache() { } channelsIDM = newChannelId2channel channelSyncLock.Unlock() - logger.SysLog("channels synced from database") + common.SysLog("channels synced from database") } func SyncChannelCache(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - logger.SysLog("syncing channels from database") + common.SysLog("syncing channels from database") InitChannelCache() } } diff --git a/model/log.go b/model/log.go index d9495968..e443516d 100644 --- a/model/log.go +++ b/model/log.go @@ -88,7 +88,7 @@ func RecordLog(userId int, logType int, content string) { } err := LOG_DB.Create(log).Error if err != nil { - logger.SysError("failed to record log: " + err.Error()) + common.SysLog("failed to record log: " + err.Error()) } } diff --git a/model/main.go b/model/main.go index 1e582e1a..dbf27152 100644 --- a/model/main.go +++ b/model/main.go @@ -5,7 +5,6 @@ import ( "log" "one-api/common" "one-api/constant" - "one-api/logger" "os" "strings" "sync" @@ -85,7 +84,7 @@ func createRootAccountIfNeed() error { var user User //if user.Status != common.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") + common.SysLog("no user exists, create a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return err @@ -109,7 +108,7 @@ func CheckSetup() { if setup == nil { // No setup record exists, check if we have a root user if RootUserExists() { - logger.SysLog("system is not initialized, but root user exists") + common.SysLog("system is not initialized, but root user exists") // Create setup record newSetup := Setup{ Version: common.Version, @@ -117,16 +116,16 @@ func CheckSetup() { } err := DB.Create(&newSetup).Error if err != nil { - logger.SysLog("failed to create setup record: " + err.Error()) + common.SysLog("failed to create setup record: " + err.Error()) } constant.Setup = true } else { - logger.SysLog("system is not initialized and no root user exists") + common.SysLog("system is not initialized and no root user exists") constant.Setup = false } } else { // Setup record exists, system is initialized - logger.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) + common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) constant.Setup = true } } @@ -139,7 +138,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { if dsn != "" { if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { // Use PostgreSQL - logger.SysLog("using PostgreSQL as database") + common.SysLog("using PostgreSQL as database") if !isLog { common.UsingPostgreSQL = true } else { @@ -153,7 +152,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } if strings.HasPrefix(dsn, "local") { - logger.SysLog("SQL_DSN not set, using SQLite as database") + common.SysLog("SQL_DSN not set, using SQLite as database") if !isLog { common.UsingSQLite = true } else { @@ -164,7 +163,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } // Use MySQL - logger.SysLog("using MySQL as database") + common.SysLog("using MySQL as database") // check parseTime if !strings.Contains(dsn, "parseTime") { if strings.Contains(dsn, "?") { @@ -183,7 +182,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } // Use SQLite - logger.SysLog("SQL_DSN not set, using SQLite as database") + common.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ PrepareStmt: true, // precompile SQL @@ -217,11 +216,11 @@ func InitDB() (err error) { if common.UsingMySQL { //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded } - logger.SysLog("database migration started") + common.SysLog("database migration started") err = migrateDB() return err } else { - logger.FatalLog(err) + common.FatalLog(err) } return err } @@ -254,11 +253,11 @@ func InitLogDB() (err error) { if !common.IsMasterNode { return nil } - logger.SysLog("database migration started") + common.SysLog("database migration started") err = migrateLOGDB() return err } else { - logger.FatalLog(err) + common.FatalLog(err) } return err } @@ -355,7 +354,7 @@ func migrateDBFast() error { return err } } - logger.SysLog("database migrated") + common.SysLog("database migrated") return nil } @@ -504,6 +503,6 @@ func PingDB() error { } lastPingTime = time.Now() - logger.SysLog("Database pinged successfully") + common.SysLog("Database pinged successfully") return nil } diff --git a/model/option.go b/model/option.go index 8fcd13a8..2121710c 100644 --- a/model/option.go +++ b/model/option.go @@ -2,7 +2,6 @@ package model import ( "one-api/common" - "one-api/logger" "one-api/setting" "one-api/setting/config" "one-api/setting/operation_setting" @@ -151,7 +150,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - logger.SysError("failed to update option map: " + err.Error()) + common.SysLog("failed to update option map: " + err.Error()) } } } @@ -159,7 +158,7 @@ func loadOptionsFromDatabase() { func SyncOptions(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - logger.SysLog("syncing options from database") + common.SysLog("syncing options from database") loadOptionsFromDatabase() } } diff --git a/model/pricing.go b/model/pricing.go index 31aa5cdf..3c9349de 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -3,7 +3,6 @@ package model import ( "encoding/json" "fmt" - "one-api/logger" "strings" "one-api/common" @@ -93,7 +92,7 @@ func updatePricing() { //modelRatios := common.GetModelRatios() enableAbilities, err := GetAllEnableAbilityWithChannels() if err != nil { - logger.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) + common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) return } // 预加载模型元数据与供应商一次,避免循环查询 diff --git a/model/token.go b/model/token.go index 63c17e2d..320b5cf0 100644 --- a/model/token.go +++ b/model/token.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "one-api/common" - "one-api/logger" "strings" "github.com/bytedance/gopkg/util/gopool" @@ -92,7 +91,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExpired err := token.SelectUpdate() if err != nil { - logger.SysError("failed to update token status" + err.Error()) + common.SysLog("failed to update token status" + err.Error()) } } return token, errors.New("该令牌已过期") @@ -103,7 +102,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExhausted err := token.SelectUpdate() if err != nil { - logger.SysError("failed to update token status" + err.Error()) + common.SysLog("failed to update token status" + err.Error()) } } keyPrefix := key[:3] @@ -135,7 +134,7 @@ func GetTokenById(id int) (*Token, error) { if shouldUpdateRedis(true, err) { gopool.Go(func() { if err := cacheSetToken(token); err != nil { - logger.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } @@ -148,7 +147,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) { if shouldUpdateRedis(fromDB, err) && token != nil { gopool.Go(func() { if err := cacheSetToken(*token); err != nil { - logger.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } @@ -179,7 +178,7 @@ func (token *Token) Update() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - logger.SysError("failed to update token cache: " + err.Error()) + common.SysLog("failed to update token cache: " + err.Error()) } }) } @@ -195,7 +194,7 @@ func (token *Token) SelectUpdate() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - logger.SysError("failed to update token cache: " + err.Error()) + common.SysLog("failed to update token cache: " + err.Error()) } }) } @@ -210,7 +209,7 @@ func (token *Token) Delete() (err error) { gopool.Go(func() { err := cacheDeleteToken(token.Key) if err != nil { - logger.SysError("failed to delete token cache: " + err.Error()) + common.SysLog("failed to delete token cache: " + err.Error()) } }) } @@ -270,7 +269,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheIncrTokenQuota(key, int64(quota)) if err != nil { - logger.SysError("failed to increase token quota: " + err.Error()) + common.SysLog("failed to increase token quota: " + err.Error()) } }) } @@ -300,7 +299,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheDecrTokenQuota(key, int64(quota)) if err != nil { - logger.SysError("failed to decrease token quota: " + err.Error()) + common.SysLog("failed to decrease token quota: " + err.Error()) } }) } diff --git a/model/twofa.go b/model/twofa.go index b2ea54e0..8e97289f 100644 --- a/model/twofa.go +++ b/model/twofa.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "one-api/common" - "one-api/logger" "time" "gorm.io/gorm" @@ -244,7 +243,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { if !common.ValidateTOTPCode(t.Secret, code) { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - logger.SysError("更新2FA失败次数失败: " + err.Error()) + common.SysLog("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -256,7 +255,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - logger.SysError("更新2FA使用记录失败: " + err.Error()) + common.SysLog("更新2FA使用记录失败: " + err.Error()) } return true, nil @@ -278,7 +277,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { if !valid { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - logger.SysError("更新2FA失败次数失败: " + err.Error()) + common.SysLog("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -290,7 +289,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - logger.SysError("更新2FA使用记录失败: " + err.Error()) + common.SysLog("更新2FA使用记录失败: " + err.Error()) } return true, nil diff --git a/model/usedata.go b/model/usedata.go index f0027a8d..1255b0be 100644 --- a/model/usedata.go +++ b/model/usedata.go @@ -4,7 +4,6 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" - "one-api/logger" "sync" "time" ) @@ -25,12 +24,12 @@ func UpdateQuotaData() { // recover defer func() { if r := recover(); r != nil { - logger.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) + common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) } }() for { if common.DataExportEnabled { - logger.SysLog("正在更新数据看板数据...") + common.SysLog("正在更新数据看板数据...") SaveQuotaDataCache() } time.Sleep(time.Duration(common.DataExportInterval) * time.Minute) @@ -92,7 +91,7 @@ func SaveQuotaDataCache() { } } CacheQuotaData = make(map[string]*QuotaData) - logger.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size)) + common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size)) } func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) { @@ -103,7 +102,7 @@ func increaseQuotaData(userId int, username string, modelName string, count int, "token_used": gorm.Expr("token_used + ?", tokenUsed), }).Error if err != nil { - logger.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err)) + common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err)) } } diff --git a/model/user.go b/model/user.go index 244380ad..29d7a446 100644 --- a/model/user.go +++ b/model/user.go @@ -76,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting { if user.Setting != "" { err := json.Unmarshal([]byte(user.Setting), &setting) if err != nil { - logger.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) } } return setting @@ -85,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting { func (user *User) SetSetting(setting dto.UserSetting) { settingBytes, err := json.Marshal(setting) if err != nil { - logger.SysError("failed to marshal setting: " + err.Error()) + common.SysLog("failed to marshal setting: " + err.Error()) return } user.Setting = string(settingBytes) @@ -518,7 +518,7 @@ func IsAdmin(userId int) bool { var user User err := DB.Where("id = ?", userId).Select("role").Find(&user).Error if err != nil { - logger.SysError("no such user " + err.Error()) + common.SysLog("no such user " + err.Error()) return false } return user.Role >= common.RoleAdminUser @@ -573,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserQuotaCache(id, quota); err != nil { - logger.SysError("failed to update user quota cache: " + err.Error()) + common.SysLog("failed to update user quota cache: " + err.Error()) } }) } @@ -611,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserGroupCache(id, group); err != nil { - logger.SysError("failed to update user group cache: " + err.Error()) + common.SysLog("failed to update user group cache: " + err.Error()) } }) } @@ -640,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserSettingCache(id, setting); err != nil { - logger.SysError("failed to update user setting cache: " + err.Error()) + common.SysLog("failed to update user setting cache: " + err.Error()) } }) } @@ -670,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) { gopool.Go(func() { err := cacheIncrUserQuota(id, int64(quota)) if err != nil { - logger.SysError("failed to increase user quota: " + err.Error()) + common.SysLog("failed to increase user quota: " + err.Error()) } }) if !db && common.BatchUpdateEnabled { @@ -695,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { gopool.Go(func() { err := cacheDecrUserQuota(id, int64(quota)) if err != nil { - logger.SysError("failed to decrease user quota: " + err.Error()) + common.SysLog("failed to decrease user quota: " + err.Error()) } }) if common.BatchUpdateEnabled { @@ -751,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { }, ).Error if err != nil { - logger.SysError("failed to update user used quota and request count: " + err.Error()) + common.SysLog("failed to update user used quota and request count: " + err.Error()) return } @@ -768,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) { }, ).Error if err != nil { - logger.SysError("failed to update user used quota: " + err.Error()) + common.SysLog("failed to update user used quota: " + err.Error()) } } func updateUserRequestCount(id int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error if err != nil { - logger.SysError("failed to update user request count: " + err.Error()) + common.SysLog("failed to update user request count: " + err.Error()) } } @@ -786,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserNameCache(id, username); err != nil { - logger.SysError("failed to update user name cache: " + err.Error()) + common.SysLog("failed to update user name cache: " + err.Error()) } }) } diff --git a/model/user_cache.go b/model/user_cache.go index dec7597b..936e1a43 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -5,7 +5,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "time" "github.com/gin-gonic/gin" @@ -38,7 +37,7 @@ func (user *UserBase) GetSetting() dto.UserSetting { if user.Setting != "" { err := common.Unmarshal([]byte(user.Setting), &setting) if err != nil { - logger.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) } } return setting @@ -79,7 +78,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) { if shouldUpdateRedis(fromDB, err) && user != nil { gopool.Go(func() { if err := updateUserCache(*user); err != nil { - logger.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } diff --git a/model/utils.go b/model/utils.go index abd96b79..dced2bc6 100644 --- a/model/utils.go +++ b/model/utils.go @@ -3,7 +3,6 @@ package model import ( "errors" "one-api/common" - "one-api/logger" "sync" "time" @@ -66,7 +65,7 @@ func batchUpdate() { return } - logger.SysLog("batch update started") + common.SysLog("batch update started") for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] @@ -78,12 +77,12 @@ func batchUpdate() { case BatchUpdateTypeUserQuota: err := increaseUserQuota(key, value) if err != nil { - logger.SysError("failed to batch update user quota: " + err.Error()) + common.SysLog("failed to batch update user quota: " + err.Error()) } case BatchUpdateTypeTokenQuota: err := increaseTokenQuota(key, value) if err != nil { - logger.SysError("failed to batch update token quota: " + err.Error()) + common.SysLog("failed to batch update token quota: " + err.Error()) } case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) @@ -94,7 +93,7 @@ func batchUpdate() { } } } - logger.SysLog("batch update finished") + common.SysLog("batch update finished") } func RecordExist(err error) (bool, error) { diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index bfb94008..0ae8a8d1 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -34,20 +34,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { var fullRequestURL string switch info.RelayFormat { - case relaycommon.RelayFormatClaude: - fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl) + case types.RelayFormatClaude: + fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl) default: switch info.RelayMode { case constant.RelayModeEmbeddings: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl) case constant.RelayModeRerank: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl) case constant.RelayModeImagesGenerations: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl) case constant.RelayModeCompletions: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl) default: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl) } } @@ -118,7 +118,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayFormat { - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: if info.IsStream { err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 841896cf..645882bc 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "one-api/common" "one-api/dto" "one-api/logger" relaycommon "one-api/relay/common" @@ -22,14 +23,14 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { imageRequest.Input.Prompt = request.Prompt imageRequest.Model = request.Model imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) - imageRequest.Parameters.N = request.N + imageRequest.Parameters.N = int(request.N) imageRequest.ResponseFormat = request.ResponseFormat return &imageRequest } func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) { - url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID) + url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID) var aliResponse AliResponse @@ -43,7 +44,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error client := &http.Client{} resp, err := client.Do(req) if err != nil { - logger.SysError("updateTask client.Do err: " + err.Error()) + common.SysLog("updateTask client.Do err: " + err.Error()) return &aliResponse, err, nil } defer resp.Body.Close() @@ -53,7 +54,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error var response AliResponse err = json.Unmarshal(responseBody, &response) if err != nil { - logger.SysError("updateTask NewDecoder err: " + err.Error()) + common.SysLog("updateTask NewDecoder err: " + err.Error()) return &aliResponse, err, nil } diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index 17fcef2a..67b63286 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -7,7 +7,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" "one-api/relay/helper" "one-api/service" "strings" @@ -150,7 +149,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, var aliResponse AliResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } if aliResponse.Usage.OutputTokens != 0 { @@ -163,7 +162,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 8396a844..32e301ee 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -101,7 +101,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { default: suffix += strings.ToLower(info.UpstreamModelName) } - fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix) + fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix) var accessToken string var err error if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 696c2496..31e8319e 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -9,7 +9,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -119,7 +118,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var baiduResponse BaiduChatStreamResponse err := common.Unmarshal([]byte(data), &baiduResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } if baiduResponse.Usage.TotalTokens != 0 { @@ -130,7 +129,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseBaidu2OpenAI(&baiduResponse) err = helper.ObjectData(c, response) if err != nil { - logger.SysError("error sending stream response: " + err.Error()) + common.SysLog("error sending stream response: " + err.Error()) } return true }) diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index ba59e307..6744f8ba 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -45,15 +45,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil case constant.RelayModeRerank: - return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 39b8ce2f..41583d30 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -53,9 +53,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if a.RequestMode == RequestModeMessage { - return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil } else { - return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 5d839908..57670bcf 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -376,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla for _, toolCall := range message.ParseToolCalls() { inputObj := make(map[string]any) if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { - logger.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) + common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) continue } claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ @@ -610,13 +610,13 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud var claudeResponse dto.ClaudeResponse err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return types.NewError(err, types.ErrorCodeBadResponseBody) } if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" { return types.WithClaudeError(*claudeError, http.StatusInternalServerError) } - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo) if requestMode == RequestModeCompletion { @@ -629,7 +629,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } } helper.ClaudeChunkData(c, claudeResponse, data) - } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { + } else if info.RelayFormat == types.RelayFormatOpenAI { response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) { @@ -654,21 +654,20 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau } if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done { if common.DebugEnabled { - logger.SysError("claude response usage is not complete, maybe upstream error") + common.SysLog("claude response usage is not complete, maybe upstream error") } claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } } - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { // - } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { - + } else if info.RelayFormat == types.RelayFormatOpenAI { if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { - logger.SysError("send final response failed: " + err.Error()) + common.SysLog("send final response failed: " + err.Error()) } } helper.Done(c) @@ -722,14 +721,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } var responseData []byte switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) openaiResponse.Usage = *claudeInfo.Usage responseData, err = json.Marshal(openaiResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody) } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: responseData = data } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 4b9f5028..bdea72f0 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -36,13 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil case constant.RelayModeResponses: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil default: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil } } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 887f9efd..c8a38d46 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -43,9 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else { - return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index ccef9b23..af357348 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -7,7 +7,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -119,7 +118,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var cohereResp CohereResponse err := json.Unmarshal([]byte(data), &cohereResp) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } var openaiResp dto.ChatCompletionsStreamResponse @@ -154,7 +153,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } jsonStr, err := json.Marshal(openaiResp) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index 658c6193..0f2a6fd3 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -122,7 +122,7 @@ func (a *Adaptor) GetModelList() []string { // GetRequestURL implements channel.Adaptor. func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil + return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil } // Init implements channel.Adaptor. diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 18ed46af..c480045f 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -9,7 +9,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -155,7 +154,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var chatData CozeChatResponseData err := json.Unmarshal([]byte(data), &chatData) if err != nil { - logger.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } @@ -172,14 +171,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var messageData CozeChatV3MessageDetail err := json.Unmarshal([]byte(data), &messageData) if err != nil { - logger.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } var content string err = json.Unmarshal(messageData.Content, &content) if err != nil { - logger.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } @@ -204,16 +203,16 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var errorData CozeError err := json.Unmarshal([]byte(data), &errorData) if err != nil { - logger.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } - logger.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) } } func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { - requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) + requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl) requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") // 将 conversationId和chatId作为参数发送get请求 @@ -259,7 +258,7 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo } func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) { - requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl) + requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl) requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") req, err := http.NewRequest("GET", requestURL, nil) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index be8de0c8..17d732ab 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - fimBaseUrl := info.BaseUrl - if !strings.HasSuffix(info.BaseUrl, "/beta") { + fimBaseUrl := info.ChannelBaseUrl + if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { fimBaseUrl += "/beta" } switch info.RelayMode { case constant.RelayModeCompletions: return fmt.Sprintf("%s/completions", fimBaseUrl), nil default: - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 8c7898c9..0a08d035 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -61,13 +61,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch a.BotType { case BotTypeWorkFlow: - return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil case BotTypeCompletion: - return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil case BotTypeAgent: fallthrough default: - return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index f03d61a4..2336fd4c 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -11,7 +11,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -23,7 +22,7 @@ import ( ) func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile { - uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl) + uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl) switch media.Type { case dto.ContentTypeImageURL: // Decode base64 data @@ -37,14 +36,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Decode base64 string decodedData, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { - logger.SysError("failed to decode base64: " + err.Error()) + common.SysLog("failed to decode base64: " + err.Error()) return nil } // Create temporary file tempFile, err := os.CreateTemp("", "dify-upload-*") if err != nil { - logger.SysError("failed to create temp file: " + err.Error()) + common.SysLog("failed to create temp file: " + err.Error()) return nil } defer tempFile.Close() @@ -52,7 +51,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Write decoded data to temp file if _, err := tempFile.Write(decodedData); err != nil { - logger.SysError("failed to write to temp file: " + err.Error()) + common.SysLog("failed to write to temp file: " + err.Error()) return nil } @@ -62,7 +61,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Add user field if err := writer.WriteField("user", user); err != nil { - logger.SysError("failed to add user field: " + err.Error()) + common.SysLog("failed to add user field: " + err.Error()) return nil } @@ -75,13 +74,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create form file part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/"))) if err != nil { - logger.SysError("failed to create form file: " + err.Error()) + common.SysLog("failed to create form file: " + err.Error()) return nil } // Copy file content to form if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil { - logger.SysError("failed to copy file content: " + err.Error()) + common.SysLog("failed to copy file content: " + err.Error()) return nil } writer.Close() @@ -89,7 +88,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create HTTP request req, err := http.NewRequest("POST", uploadUrl, body) if err != nil { - logger.SysError("failed to create request: " + err.Error()) + common.SysLog("failed to create request: " + err.Error()) return nil } @@ -100,7 +99,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me client := service.GetHttpClient() resp, err := client.Do(req) if err != nil { - logger.SysError("failed to send request: " + err.Error()) + common.SysLog("failed to send request: " + err.Error()) return nil } defer resp.Body.Close() @@ -110,7 +109,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me Id string `json:"id"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.SysError("failed to decode response: " + err.Error()) + common.SysLog("failed to decode response: " + err.Error()) return nil } @@ -220,7 +219,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R var difyResponse DifyChunkChatCompletionResponse err := json.Unmarshal([]byte(data), &difyResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } var openaiResponse dto.ChatCompletionsStreamResponse @@ -240,7 +239,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } err = helper.ObjectData(c, openaiResponse) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) } return true }) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 05d974f6..99b6645e 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -108,7 +108,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName) if strings.HasPrefix(info.UpstreamModelName, "imagen") { - return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil + return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil } if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || @@ -118,7 +118,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.IsGeminiBatchEmbedding { action = "batchEmbedContents" } - return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil + return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil } action := "generateContent" @@ -128,7 +128,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { info.DisablePing = true } } - return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil + return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 82a2d8de..af5e8233 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -994,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) err := handleFinalStream(c, info, response) if err != nil { - logger.SysError("send final response failed: " + err.Error()) + common.SysLog("send final response failed: " + err.Error()) } //if info.RelayFormat == relaycommon.RelayFormatOpenAI { // helper.Done(c) @@ -1042,19 +1042,19 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R fullTextResponse.Usage = usage switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: responseBody, err = common.Marshal(fullTextResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info) claudeRespStr, err := common.Marshal(claudeResp) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } responseBody = claudeRespStr - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: break } diff --git a/relay/channel/jimeng/adaptor.go b/relay/channel/jimeng/adaptor.go index ff9ac678..885a1427 100644 --- a/relay/channel/jimeng/adaptor.go +++ b/relay/channel/jimeng/adaptor.go @@ -32,7 +32,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil + return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index bf318aa7..a383728f 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -45,9 +45,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } return "", errors.New("invalid relay mode") } diff --git a/relay/channel/minimax/relay-minimax.go b/relay/channel/minimax/relay-minimax.go index d0a15b0d..ff9b72ea 100644 --- a/relay/channel/minimax/relay-minimax.go +++ b/relay/channel/minimax/relay-minimax.go @@ -6,5 +6,5 @@ import ( ) func GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 45cb3290..f98ff869 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -41,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index 37db2aec..f9da685f 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -54,7 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if strings.HasPrefix(info.UpstreamModelName, "m3e") { suffix = "embeddings" } - fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix) + fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix) return fullRequestURL, nil } diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go index d540388d..29004d0c 100644 --- a/relay/channel/moonshot/adaptor.go +++ b/relay/channel/moonshot/adaptor.go @@ -44,19 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayFormat { - case relaycommon.RelayFormatClaude: - return fmt.Sprintf("%s/anthropic/v1/messages", info.BaseUrl), nil + case types.RelayFormatClaude: + return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil default: if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeChatCompletions { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeCompletions { - return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } } @@ -89,10 +89,10 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: adaptor := openai.Adaptor{} return adaptor.DoResponse(c, resp, info) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: if info.IsStream { err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 1f3fda8d..1a0caf75 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -48,14 +48,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - if info.RelayFormat == relaycommon.RelayFormatClaude { - return info.BaseUrl + "/v1/chat/completions", nil + if info.RelayFormat == types.RelayFormatClaude { + return info.ChannelBaseUrl + "/v1/chat/completions", nil } switch info.RelayMode { case relayconstant.RelayModeEmbeddings: - return info.BaseUrl + "/api/embed", nil + return info.ChannelBaseUrl + "/api/embed", nil default: - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index fc1749a0..d783b9d8 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -105,14 +105,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == relayconstant.RelayModeRealtime { - if strings.HasPrefix(info.BaseUrl, "https://") { - baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") + if strings.HasPrefix(info.ChannelBaseUrl, "https://") { + baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://") baseUrl = "wss://" + baseUrl - info.BaseUrl = baseUrl - } else if strings.HasPrefix(info.BaseUrl, "http://") { - baseUrl := strings.TrimPrefix(info.BaseUrl, "http://") + info.ChannelBaseUrl = baseUrl + } else if strings.HasPrefix(info.ChannelBaseUrl, "http://") { + baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://") baseUrl = "ws://" + baseUrl - info.BaseUrl = baseUrl + info.ChannelBaseUrl = baseUrl } } switch info.ChannelType { @@ -126,7 +126,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) task := strings.TrimPrefix(requestURL, "/v1/") - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { task = strings.TrimPrefix(task, "messages") task = "chat/completions" + task } @@ -136,7 +136,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { responsesApiVersion := "preview" subUrl := "/openai/v1/responses" - if strings.Contains(info.BaseUrl, "cognitiveservices.azure.com") { + if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") { subUrl = "/openai/responses" responsesApiVersion = apiVersion } @@ -146,7 +146,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion) - return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil } model_ := info.UpstreamModelName @@ -159,18 +159,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == relayconstant.RelayModeRealtime { requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion) } - return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil case constant.ChannelTypeMiniMax: return minimax.GetRequestURL(info) case constant.ChannelTypeCustom: - url := info.BaseUrl + url := info.ChannelBaseUrl url = strings.Replace(url, "{model}", info.UpstreamModelName, -1) return url, nil default: - if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini { + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 80973aa1..2a4b4938 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -12,6 +12,7 @@ import ( relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -22,11 +23,11 @@ func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string info.SendResponseCount++ switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: return sendStreamData(c, info, data, forceFormat, thinkToContent) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: return handleClaudeFormat(c, data, info) - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: return handleGeminiFormat(c, data, info) } return nil @@ -111,14 +112,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex var streamResponses []dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { return err } if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { - logger.SysError("error processing stream response: " + err.Error()) + common.SysLog("error processing stream response: " + err.Error()) } } return nil @@ -147,7 +148,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui var streamResponses []dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { @@ -202,7 +203,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream usage *dto.Usage, containStreamUsage bool) { switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: if info.ShouldIncludeUsage && !containStreamUsage { response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) @@ -210,11 +211,11 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream } helper.Done(c) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: info.ClaudeConvertInfo.Done = true var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return } @@ -225,10 +226,10 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream _ = helper.ClaudeData(c, *resp) } - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return } @@ -246,7 +247,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { - logger.SysError("error marshalling gemini response: " + err.Error()) + common.SysLog("error marshalling gemini response: " + err.Error()) return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 447e0f31..00dde46d 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -130,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re if lastStreamData != "" { err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) if err != nil { - logger.SysError("error handling stream format: " + err.Error()) + common.SysLog("error handling stream format: " + err.Error()) } } if len(data) > 0 { @@ -147,7 +147,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData)) } - if info.RelayFormat == relaycommon.RelayFormatOpenAI { + if info.RelayFormat == types.RelayFormatOpenAI { if shouldSendLastResp { _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) } @@ -211,7 +211,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: if forceFormat { responseBody, err = common.Marshal(simpleResponse) if err != nil { @@ -220,14 +220,14 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } else { break } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeRespStr, err := common.Marshal(claudeResp) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } responseBody = claudeRespStr - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info) geminiRespStr, err := common.Marshal(geminiResp) if err != nil { diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 4d1ab783..2a022a1b 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil + return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 1264b2b4..3a6ec2f4 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -7,7 +7,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -59,7 +58,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, go func() { responseBody, err := io.ReadAll(resp.Body) if err != nil { - logger.SysError("error reading stream response: " + err.Error()) + common.SysLog("error reading stream response: " + err.Error()) stopChan <- true return } @@ -67,7 +66,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) stopChan <- true return } @@ -79,7 +78,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, } jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) stopChan <- true return } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 92cb08a2..8ab9c854 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 05e6d453..4c176c08 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeChatCompletions { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeCompletions { - return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 8d057513..a5ada137 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -76,7 +76,7 @@ type TaskAdaptor struct { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl // apiKey format: "access_key|secret_key" keyParts := strings.Split(info.ApiKey, "|") diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index b7b9a5ff..1fecda08 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -81,7 +81,7 @@ type TaskAdaptor struct { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl a.apiKey = info.ApiKey // apiKey format: "access_key|secret_key" diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 1deb33fd..df2bb99e 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -11,7 +11,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" @@ -60,7 +59,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { - baseURL := info.BaseUrl + baseURL := info.ChannelBaseUrl fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action) return fullRequestURL, nil } @@ -140,7 +139,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody)) if err != nil { - logger.SysError(fmt.Sprintf("Get Task error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task error: %v", err)) return nil, err } defer req.Body.Close() diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index f40b480c..b0cc0bdc 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -86,7 +86,7 @@ type TaskAdaptor struct { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError { diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index b86d8a16..ab96ecaa 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/", info.BaseUrl), nil + return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index d3aeab3f..f33a275c 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -13,7 +13,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -107,7 +106,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt var tencentResponse TencentChatResponse err := json.Unmarshal([]byte(data), &tencentResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) continue } @@ -118,12 +117,12 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt err = helper.ObjectData(c, response) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) } } if err := scanner.Err(); err != nil { - logger.SysError("error reading stream: " + err.Error()) + common.SysLog("error reading stream: " + err.Error()) } helper.Done(c) diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 2cc4f663..b46cb952 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -188,17 +188,17 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: if strings.HasPrefix(info.UpstreamModelName, "bot") { - return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil case constant.RelayModeRerank: - return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 6a3a5370..d5671ab2 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf xaiRequest := ImageRequest{ Model: request.Model, Prompt: request.Prompt, - N: request.N, + N: int(request.N), ResponseFormat: request.ResponseFormat, } return xaiRequest, nil @@ -49,7 +49,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 4d4e7b92..5cae9c0a 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -6,7 +6,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -48,7 +47,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var xAIResp *dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &xAIResp) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } @@ -64,7 +63,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount) err = helper.ObjectData(c, openaiResponse) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) } return true }) diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 398bb08d..54ed476f 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -11,7 +11,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/relay/helper" "one-api/types" "strings" @@ -144,7 +143,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -219,20 +218,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap for { _, msg, err := conn.ReadMessage() if err != nil { - logger.SysError("error reading stream response: " + err.Error()) + common.SysLog("error reading stream response: " + err.Error()) break } var response XunfeiChatResponse err = json.Unmarshal(msg, &response) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() if err != nil { - logger.SysError("error closing websocket connection: " + err.Error()) + common.SysLog("error closing websocket connection: " + err.Error()) } break } @@ -283,6 +282,6 @@ func getAPIVersion(c *gin.Context, modelName string) string { return apiVersion } apiVersion = "v1.1" - logger.SysLog("api_version not found, using default: " + apiVersion) + common.SysLog("api_version not found, using default: " + apiVersion) return apiVersion } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index e3be0e8e..bd27c90b 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -45,7 +45,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.IsStream { method = "sse-invoke" } - return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil + return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 65b662b6..8eb0dcc1 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -8,7 +8,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -40,7 +39,7 @@ func getZhipuToken(apikey string) string { split := strings.Split(apikey, ".") if len(split) != 2 { - logger.SysError("invalid zhipu key: " + apikey) + common.SysLog("invalid zhipu key: " + apikey) return "" } @@ -188,7 +187,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseZhipu2OpenAI(data) jsonResponse, err := json.Marshal(response) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -197,13 +196,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var zhipuResponse ZhipuStreamMetaResponse err := json.Unmarshal([]byte(data), &zhipuResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(response) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } usage = zhipuUsage diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index a83e30e6..0fae3767 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -43,7 +43,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl) + baseUrl := fmt.Sprintf("%s/api/paas/v4", info.ChannelBaseUrl) switch info.RelayMode { case relayconstant.RelayModeEmbeddings: return fmt.Sprintf("%s/embeddings", baseUrl), nil diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 59be0011..742cd61c 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -66,6 +66,7 @@ type ChannelMeta struct { ChannelOtherSettings dto.ChannelOtherSettings UpstreamModelName string IsModelMapped bool + SupportStreamOptions bool // 是否支持流式选项 } type RelayInfo struct { @@ -86,9 +87,9 @@ type RelayInfo struct { RelayMode int OriginModelName string //RecodeModelName string - RequestURLPath string - PromptTokens int - SupportStreamOptions bool + RequestURLPath string + PromptTokens int + //SupportStreamOptions bool ShouldIncludeUsage bool DisablePing bool // 是否禁止向下游发送自定义 Ping ClientWs *websocket.Conn @@ -135,6 +136,7 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) { ParamOverride: paramOverride, UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), IsModelMapped: false, + SupportStreamOptions: false, } channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) @@ -146,6 +148,10 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) { if ok { channelMeta.ChannelOtherSettings = channelOtherSettings } + + if streamSupportedChannels[channelMeta.ChannelType] { + channelMeta.SupportStreamOptions = true + } info.ChannelMeta = channelMeta } @@ -268,6 +274,12 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { startTime = time.Now() } + isStream := false + + if request != nil { + isStream = request.IsStream(c) + } + // firstResponseTime = time.Now() - 1 second info := &RelayInfo{ @@ -289,7 +301,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RequestURLPath: c.Request.URL.String(), - IsStream: request.IsStream(c), + IsStream: isStream, StartTime: startTime, FirstResponseTime: startTime.Add(-time.Second), @@ -339,6 +351,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req return GenRelayInfoResponses(c, request), nil } return nil, errors.New("request is not a OpenAIResponsesRequest") + case types.RelayFormatTask: + return genBaseRelayInfo(c, nil), nil + case types.RelayFormatMjProxy: + return genBaseRelayInfo(c, nil), nil default: return nil, errors.New("invalid relay format") } @@ -367,11 +383,15 @@ type TaskRelayInfo struct { ConsumeQuota bool } -func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { - info := &TaskRelayInfo{ - RelayInfo: GenRelayInfo(c), +func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) { + relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil) + if err != nil { + return nil, err } - return info + info := &TaskRelayInfo{ + RelayInfo: relayInfo, + } + return info, nil } type TaskSubmitReq struct { diff --git a/relay/helper/price.go b/relay/helper/price.go index 89fc3b66..fdc5b66d 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -53,9 +53,9 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var imageRatio float64 var cacheCreationRatio float64 if !usePrice { - preConsumedTokens := common.PreConsumedQuota + preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota) if meta.MaxTokens != 0 { - preConsumedTokens = promptTokens + meta.MaxTokens + preConsumedTokens += meta.MaxTokens } var success bool var matchName string @@ -102,27 +102,27 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens } // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -//func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { -// groupRatioInfo := HandleGroupRatio(c, info) -// -// modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) -// // 如果没有配置价格,则使用默认价格 -// if !success { -// defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] -// if !ok { -// modelPrice = 0.1 -// } else { -// modelPrice = defaultPrice -// } -// } -// quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) -// priceData := types.PerCallPriceData{ -// ModelPrice: modelPrice, -// Quota: quota, -// GroupRatioInfo: groupRatioInfo, -// } -// return priceData -//} +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { + groupRatioInfo := HandleGroupRatio(c, info) + + modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) + // 如果没有配置价格,则使用默认价格 + if !success { + defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] + if !ok { + modelPrice = 0.1 + } else { + modelPrice = defaultPrice + } + } + quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) + priceData := types.PerCallPriceData{ + ModelPrice: modelPrice, + Quota: quota, + GroupRatioInfo: groupRatioInfo, + } + return priceData +} func ContainPriceOrRatio(modelName string) bool { _, ok := ratio_setting.GetModelPrice(modelName, false) diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 0bc51774..1d556a33 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -36,7 +36,7 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt case types.RelayFormatOpenAIAudio: request, err = GetAndValidAudioRequest(c, relayMode) case types.RelayFormatOpenAIRealtime: - // nothing to do, no request body + request = &dto.BaseRequest{} default: return nil, fmt.Errorf("unsupported relay format: %s", format) } diff --git a/relay/chat_handler.go b/relay/mjproxy_handler.go similarity index 87% rename from relay/chat_handler.go rename to relay/mjproxy_handler.go index 30bce55c..756ad450 100644 --- a/relay/chat_handler.go +++ b/relay/mjproxy_handler.go @@ -10,7 +10,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -171,13 +170,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo return } -func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { - startTime := time.Now().UnixNano() / int64(time.Millisecond) - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - //group := c.GetString("group") - channelId := c.GetInt("channel_id") - relayInfo := relaycommon.GenRelayInfo(c) +func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse { var swapFaceRequest dto.SwapFaceRequest err := common.UnmarshalBodyReusable(c, &swapFaceRequest) if err != nil { @@ -188,9 +181,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } modelName := service.CoverActionToModelName(constant.MjActionSwapFace) - priceData := helper.ModelPriceHelperPerCall(c, relayInfo) + priceData := helper.ModelPriceHelperPerCall(c, info) - userQuota, err := model.GetUserQuota(userId, false) + userQuota, err := model.GetUserQuota(info.UserId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -213,32 +206,31 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func() { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) + err := service.PostConsumeQuota(info, priceData.Quota, 0, true) if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) other := service.GenerateMjOtherInfo(priceData) - model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: channelId, + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: priceData.Quota, Content: logContent, - TokenId: tokenId, - UserQuota: userQuota, - Group: relayInfo.UsingGroup, + TokenId: info.TokenId, + Group: info.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) - model.UpdateChannelUsedQuota(channelId, priceData.Quota) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota) + model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota) } }() midjResponse := &mjResp.Response midjourneyTask := &model.Midjourney{ - UserId: userId, + UserId: info.UserId, Code: midjResponse.Code, Action: constant.MjActionSwapFace, MjId: midjResponse.Result, @@ -246,7 +238,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { PromptEn: "", Description: midjResponse.Description, State: "", - SubmitTime: startTime, + SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond), StartTime: time.Now().UnixNano() / int64(time.Millisecond), FinishTime: 0, ImageUrl: "", @@ -370,14 +362,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse return nil } -func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { - - //tokenId := c.GetInt("token_id") - //channelType := c.GetInt("channel") - userId := c.GetInt("id") - group := c.GetString("group") - channelId := c.GetInt("channel_id") - relayInfo := relaycommon.GenRelayInfo(c) +func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse { consumeQuota := true var midjRequest dto.MidjourneyRequest err := common.UnmarshalBodyReusable(c, &midjRequest) @@ -385,35 +370,35 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") } - if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 mjErr := service.CoverPlusActionToNormalAction(&midjRequest) if mjErr != nil { return mjErr } - relayMode = relayconstant.RelayModeMidjourneyChange + relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange } - if relayMode == relayconstant.RelayModeMidjourneyVideo { + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { midjRequest.Action = constant.MjActionVideo } - if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") } midjRequest.Action = constant.MjActionImagine - } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 midjRequest.Action = constant.MjActionDescribe - } else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复 midjRequest.Action = constant.MjActionEdits - } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only midjRequest.Action = constant.MjActionShorten - } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionBlend - } else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionUpload } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 mjId := "" - if relayMode == relayconstant.RelayModeMidjourneyChange { + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange { if midjRequest.TaskId == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") } else if midjRequest.Action == "" { @@ -423,7 +408,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } //action = midjRequest.Action mjId = midjRequest.TaskId - } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange { if midjRequest.Content == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") } @@ -433,13 +418,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } mjId = params.TaskId midjRequest.Action = params.Action - } else if relayMode == relayconstant.RelayModeMidjourneyModal { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal { //if midjRequest.MaskBase64 == "" { // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") //} mjId = midjRequest.TaskId midjRequest.Action = constant.MjActionModal - } else if relayMode == relayconstant.RelayModeMidjourneyVideo { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { midjRequest.Action = constant.MjActionVideo if midjRequest.TaskId == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") @@ -449,12 +434,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons mjId = midjRequest.TaskId } - originTask := model.GetByMJId(userId, mjId) + originTask := model.GetByMJId(relayInfo.UserId, mjId) if originTask == nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 if setting.MjActionCheckSuccessEnabled { - if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { + if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") } } @@ -497,7 +482,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons priceData := helper.ModelPriceHelperPerCall(c, relayInfo) - userQuota, err := model.GetUserQuota(userId, false) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -522,24 +507,23 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) other := service.GenerateMjOtherInfo(priceData) model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: channelId, + ChannelId: relayInfo.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: priceData.Quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, - Group: group, + Group: relayInfo.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) - model.UpdateChannelUsedQuota(channelId, priceData.Quota) + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota) } }() @@ -551,7 +535,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons // 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}} // other: 提交错误,description为错误描述 midjourneyTask := &model.Midjourney{ - UserId: userId, + UserId: relayInfo.UserId, Code: midjResponse.Code, Action: midjRequest.Action, MjId: midjResponse.Result, @@ -573,7 +557,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //无实例账号自动禁用渠道(No available account instance) channel, err := model.GetChannelById(midjourneyTask.ChannelId, true) if err != nil { - logger.SysError("get_channel_null: " + err.Error()) + common.SysLog("get_channel_null: " + err.Error()) } if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") diff --git a/relay/relay-text.go b/relay/relay-text.go index de750e76..5c07c718 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -44,6 +44,26 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } + includeUsage := true + // 判断用户是否需要返回使用情况 + if textRequest.StreamOptions != nil { + includeUsage = textRequest.StreamOptions.IncludeUsage + } + + // 如果不支持StreamOptions,将StreamOptions设置为nil + if !info.SupportStreamOptions || !textRequest.Stream { + textRequest.StreamOptions = nil + } else { + // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions + if constant.ForceStreamOption { + textRequest.StreamOptions = &dto.StreamOptions{ + IncludeUsage: true, + } + } + } + + info.ShouldIncludeUsage = includeUsage + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) diff --git a/relay/relay_task.go b/relay/relay_task.go index ae002d73..95b8083b 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -10,7 +10,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -28,7 +27,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if platform == "" { platform = GetTaskPlatform(c) } - relayInfo := relaycommon.GenTaskRelayInfo(c) + + relayInfo, err := relaycommon.GenTaskRelayInfo(c) + if err != nil { + return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError) + } adaptor := GetTaskAdaptor(platform) if adaptor == nil { @@ -98,7 +101,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { c.Set("channel_id", originTask.ChannelId) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - relayInfo.BaseUrl = channel.GetBaseURL() + relayInfo.ChannelBaseUrl = channel.GetBaseURL() relayInfo.ChannelId = originTask.ChannelId } } @@ -128,7 +131,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") @@ -150,7 +153,6 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, Group: relayInfo.UsingGroup, Other: other, }) diff --git a/relay/websocket.go b/relay/websocket.go index 22b681f1..2d313154 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -4,7 +4,6 @@ import ( "fmt" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/relay/helper" "one-api/service" "one-api/types" @@ -12,58 +11,35 @@ import ( "github.com/gorilla/websocket" ) -func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoWs(c, ws) +func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - err := helper.ModelMappedHelper(c, relayInfo, nil) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) //var requestBody io.Reader //firstWssRequest, _ := c.Get("first_wss_request") //requestBody = bytes.NewBuffer(firstWssRequest.([]byte)) statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, nil) + resp, err := adaptor.DoRequest(c, info, nil) if err != nil { return types.NewError(err, types.ErrorCodeDoRequestFailed) } if resp != nil { - relayInfo.TargetWs = resp.(*websocket.Conn) - defer relayInfo.TargetWs.Close() + info.TargetWs = resp.(*websocket.Conn) + defer info.TargetWs.Close() } - usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, nil, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, - userQuota, priceData, "") + service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "") return nil } diff --git a/router/main.go b/router/main.go index 7653f3a5..23576427 100644 --- a/router/main.go +++ b/router/main.go @@ -3,12 +3,12 @@ package router import ( "embed" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" - "one-api/logger" "os" "strings" + + "github.com/gin-gonic/gin" ) func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { @@ -19,7 +19,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" - logger.SysLog("FRONTEND_BASE_URL is ignored on master node") + common.SysLog("FRONTEND_BASE_URL is ignored on master node") } if frontendBaseUrl == "" { SetWebRouter(router, buildFS, indexPage) diff --git a/service/cf_worker.go b/service/cf_worker.go index 65f7f133..ae6e1ffe 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "net/http" - "one-api/logger" + "one-api/common" "one-api/setting" "strings" ) @@ -44,14 +44,14 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { if setting.EnableWorker() { - logger.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) req := &WorkerRequest{ URL: originUrl, Key: setting.WorkerValidKey, } return DoWorkerRequest(req) } else { - logger.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) return http.Get(originUrl) } } diff --git a/service/error.go b/service/error.go index 668731b0..ef5cbbde 100644 --- a/service/error.go +++ b/service/error.go @@ -7,7 +7,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" "one-api/types" "strconv" "strings" @@ -59,7 +58,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError lowerText := strings.ToLower(text) if !strings.HasPrefix(lowerText, "get file base64 from url") { if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - logger.SysLog(fmt.Sprintf("error: %s", text)) + common.SysLog(fmt.Sprintf("error: %s", text)) text = "请求上游地址失败" } } @@ -139,7 +138,7 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { text := err.Error() lowerText := strings.ToLower(text) if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - logger.SysLog(fmt.Sprintf("error: %s", text)) + common.SysLog(fmt.Sprintf("error: %s", text)) text = "请求上游地址失败" } //避免暴露内部错误 diff --git a/service/image.go b/service/image.go index 957ca041..252093f1 100644 --- a/service/image.go +++ b/service/image.go @@ -8,8 +8,8 @@ import ( "image" "io" "net/http" + "one-api/common" "one-api/constant" - "one-api/logger" "strings" "golang.org/x/image/webp" @@ -113,7 +113,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) { func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { response, err := DoDownloadRequest(imageUrl) if err != nil { - logger.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) + common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) return image.Config{}, "", err } defer response.Body.Close() @@ -131,7 +131,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { var readData []byte for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { - logger.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) + common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) // 从response.Body读取更多的数据直到达到当前的限制 additionalData := make([]byte, limit-int64(len(readData))) @@ -157,11 +157,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) { config, format, err := image.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) - logger.SysLog(err.Error()) + common.SysLog(err.Error()) config, err = webp.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) - logger.SysLog(err.Error()) + common.SysLog(err.Error()) } format = "webp" } diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 0dae9a03..7a609c9f 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -5,7 +5,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/relay/helper" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -78,7 +78,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, return info } -func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} { +func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} { other := make(map[string]interface{}) other["model_price"] = priceData.ModelPrice other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio diff --git a/service/midjourney.go b/service/midjourney.go index 1d232739..916d02d0 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -9,7 +9,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relayconstant "one-api/relay/constant" "one-api/setting" "strconv" @@ -213,7 +212,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU defer cancel() resp, err := GetHttpClient().Do(req) if err != nil { - logger.SysError("do request failed: " + err.Error()) + common.SysLog("do request failed: " + err.Error()) return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err } statusCode := resp.StatusCode diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 3c4d0e7e..3902ef92 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -6,6 +6,7 @@ import ( "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" "net/http" + "one-api/common" "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" @@ -19,7 +20,7 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) if err != nil { - logger.SysError("error return pre-consumed quota: " + err.Error()) + common.SysLog("error return pre-consumed quota: " + err.Error()) } }) } diff --git a/service/token_counter.go b/service/token_counter.go index ec817182..43a508c1 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -10,7 +10,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/types" "strings" @@ -32,9 +31,9 @@ var tokenEncoderMap = make(map[string]tokenizer.Codec) var tokenEncoderMutex sync.RWMutex func InitTokenEncoders() { - logger.SysLog("initializing token encoders") + common.SysLog("initializing token encoders") defaultTokenEncoder = codec.NewCl100kBase() - logger.SysLog("token encoders initialized") + common.SysLog("token encoders initialized") } func getTokenEncoder(model string) tokenizer.Codec { @@ -158,7 +157,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er if strings.HasPrefix(fileMeta.Data, "http") { config, format, err = DecodeUrlImageData(fileMeta.Data) } else { - logger.SysLog(fmt.Sprintf("decoding image")) + common.SysLog(fmt.Sprintf("decoding image")) config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data) } if err != nil { @@ -248,6 +247,11 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco if meta == nil { return 0, errors.New("token count meta is nil") } + + if info.RelayFormat == types.RelayFormatOpenAIRealtime { + return 0, nil + } + model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) tkm := CountTextToken(meta.CombineText, model) diff --git a/service/user_notify.go b/service/user_notify.go index 1fcc62d3..7c864a1b 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -4,7 +4,6 @@ import ( "fmt" "one-api/common" "one-api/dto" - "one-api/logger" "one-api/model" "strings" ) @@ -13,7 +12,7 @@ func NotifyRootUser(t string, subject string, content string) { user := model.GetRootUser().ToBaseUser() err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil)) if err != nil { - logger.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error())) + common.SysLog(fmt.Sprintf("failed to notify root user: %s", err.Error())) } } @@ -26,7 +25,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // Check notification limit canSend, err := CheckNotificationLimit(userId, data.Type) if err != nil { - logger.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) + common.SysLog(fmt.Sprintf("failed to check notification limit: %s", err.Error())) return err } if !canSend { @@ -38,14 +37,14 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // check setting email userEmail = userSetting.NotificationEmail if userEmail == "" { - logger.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) + common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) return nil } return sendEmailNotify(userEmail, data) case dto.NotifyTypeWebhook: webhookURLStr := userSetting.WebhookUrl if webhookURLStr == "" { - logger.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) + common.SysLog(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) return nil } diff --git a/setting/chat.go b/setting/chat.go index b417af28..bd1e26e3 100644 --- a/setting/chat.go +++ b/setting/chat.go @@ -2,7 +2,7 @@ package setting import ( "encoding/json" - "one-api/logger" + "one-api/common" ) var Chats = []map[string]string{ @@ -37,7 +37,7 @@ func UpdateChatsByJsonString(jsonString string) error { func Chats2JsonString() string { jsonBytes, err := json.Marshal(Chats) if err != nil { - logger.SysError("error marshalling chats: " + err.Error()) + common.SysLog("error marshalling chats: " + err.Error()) return "[]" } return string(jsonBytes) diff --git a/setting/config/config.go b/setting/config/config.go index 2e43e0a7..3af51b14 100644 --- a/setting/config/config.go +++ b/setting/config/config.go @@ -2,7 +2,7 @@ package config import ( "encoding/json" - "one-api/logger" + "one-api/common" "reflect" "strconv" "strings" @@ -57,7 +57,7 @@ func (cm *ConfigManager) LoadFromDB(options map[string]string) error { // 如果找到配置项,则更新配置 if len(configMap) > 0 { if err := updateConfigFromMap(config, configMap); err != nil { - logger.SysError("failed to update config " + name + ": " + err.Error()) + common.SysError("failed to update config " + name + ": " + err.Error()) continue } } diff --git a/setting/rate_limit.go b/setting/rate_limit.go index dcb9fae5..141463e1 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" "math" - "one-api/logger" + "one-api/common" "sync" ) @@ -21,7 +21,7 @@ func ModelRequestRateLimitGroup2JSONString() string { jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - logger.SysError("error marshalling model ratio: " + err.Error()) + common.SysLog("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index 47079850..5993cdee 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -2,7 +2,7 @@ package ratio_setting import ( "encoding/json" - "one-api/logger" + "one-api/common" "sync" ) @@ -89,7 +89,7 @@ func CacheRatio2JSONString() string { defer cacheRatioMapMutex.RUnlock() jsonBytes, err := json.Marshal(cacheRatioMap) if err != nil { - logger.SysError("error marshalling cache ratio: " + err.Error()) + common.SysLog("error marshalling cache ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/group_ratio.go b/setting/ratio_setting/group_ratio.go index c1a666e9..c42553da 100644 --- a/setting/ratio_setting/group_ratio.go +++ b/setting/ratio_setting/group_ratio.go @@ -3,7 +3,7 @@ package ratio_setting import ( "encoding/json" "errors" - "one-api/logger" + "one-api/common" "sync" ) @@ -48,7 +48,7 @@ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(groupRatio) if err != nil { - logger.SysError("error marshalling model ratio: " + err.Error()) + common.SysLog("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -67,7 +67,7 @@ func GetGroupRatio(name string) float64 { ratio, ok := groupRatio[name] if !ok { - logger.SysError("group ratio not found: " + name) + common.SysLog("group ratio not found: " + name) return 1 } return ratio @@ -94,7 +94,7 @@ func GroupGroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupGroupRatio) if err != nil { - logger.SysError("error marshalling group-group ratio: " + err.Error()) + common.SysLog("error marshalling group-group ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go index bcbe712c..57e4beec 100644 --- a/setting/user_usable_group.go +++ b/setting/user_usable_group.go @@ -2,7 +2,7 @@ package setting import ( "encoding/json" - "one-api/logger" + "one-api/common" "sync" ) @@ -29,7 +29,7 @@ func UserUsableGroups2JSONString() string { jsonBytes, err := json.Marshal(userUsableGroups) if err != nil { - logger.SysError("error marshalling user groups: " + err.Error()) + common.SysLog("error marshalling user groups: " + err.Error()) } return string(jsonBytes) } diff --git a/types/relay_format.go b/types/relay_format.go index 4c29d649..6d94a70b 100644 --- a/types/relay_format.go +++ b/types/relay_format.go @@ -12,4 +12,7 @@ const ( RelayFormatOpenAIRealtime = "openai_realtime" RelayFormatRerank = "rerank" RelayFormatEmbedding = "embedding" + + RelayFormatTask = "task" + RelayFormatMjProxy = "mj_proxy" ) diff --git a/types/relay_request.go b/types/relay_request.go deleted file mode 100644 index b9d092f0..00000000 --- a/types/relay_request.go +++ /dev/null @@ -1,27 +0,0 @@ -package types - -type RelayRequest struct { - OriginRequest any - Format RelayFormat - PromptTokenCount int -} - -func (r *RelayRequest) CopyOriginRequest() any { - if r.OriginRequest == nil { - return nil - } - switch v := r.OriginRequest.(type) { - case *GeneralOpenAIRequest: - return v.Copy() - case *GeneralClaudeRequest: - return v.Copy() - case *GeneralGeminiRequest: - return v.Copy() - case *GeneralRerankRequest: - return v.Copy() - case *GeneralEmbeddingRequest: - return v.Copy() - default: - return nil - } -} From 89caccd4e0282022ba4754658deb2684eca8a822 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 14 Aug 2025 21:30:03 +0800 Subject: [PATCH 04/12] refactor: enhance quota handling and logging for pre-consume operations --- common/quota.go | 5 +++++ relay/relay-text.go | 22 +++++++++++++++++----- service/pre_consume_quota.go | 8 ++++++-- service/quota.go | 30 ++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 7 deletions(-) create mode 100644 common/quota.go diff --git a/common/quota.go b/common/quota.go new file mode 100644 index 00000000..dfd65d27 --- /dev/null +++ b/common/quota.go @@ -0,0 +1,5 @@ +package common + +func GetTrustQuota() int { + return int(10 * QuotaPerUnit) +} diff --git a/relay/relay-text.go b/relay/relay-text.go index 5c07c718..1e5aafd6 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -327,11 +327,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage totalTokens := promptTokens + completionTokens var logContent string - if !relayInfo.PriceData.UsePrice { - logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio) - } else { - logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) - } // record all the consume log even if quota is 0 if totalTokens == 0 { @@ -350,6 +345,23 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage } quotaDelta := quota - relayInfo.FinalPreConsumedQuota + + //logger.LogInfo(ctx, fmt.Sprintf("request quota delta: %s", logger.FormatQuota(quotaDelta))) + + if quotaDelta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } else if quotaDelta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } + if quotaDelta != 0 { err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 3902ef92..3906414a 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -39,13 +39,16 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if userQuota-preConsumedQuota < 0 { return 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } + + trustQuota := common.GetTrustQuota() + relayInfo.UserQuota = userQuota - if userQuota > 100*preConsumedQuota { + if userQuota > trustQuota { // 用户额度充足,判断令牌额度是否充足 if !relayInfo.TokenUnlimited { // 非无限令牌,判断令牌额度是否充足 tokenQuota := c.GetInt("token_quota") - if tokenQuota > 100*preConsumedQuota { + if tokenQuota > trustQuota { // 令牌额度充足,信任令牌 preConsumedQuota = 0 logger.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) @@ -67,6 +70,7 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if err != nil { return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) } + logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota))) } relayInfo.FinalPreConsumedQuota = preConsumedQuota return preConsumedQuota, nil diff --git a/service/quota.go b/service/quota.go index d6f49d64..9abd0af6 100644 --- a/service/quota.go +++ b/service/quota.go @@ -289,6 +289,21 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } quotaDelta := quota - relayInfo.FinalPreConsumedQuota + + if quotaDelta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } else if quotaDelta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } + if quotaDelta != 0 { err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { @@ -395,6 +410,21 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, u } quotaDelta := quota - relayInfo.FinalPreConsumedQuota + + if quotaDelta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } else if quotaDelta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } + if quotaDelta != 0 { err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { From 7f1f36806574de0541d6134bfd67116b7c89448a Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 14 Aug 2025 22:15:18 +0800 Subject: [PATCH 05/12] refactor: improve channel base URL handling and enhance RelayInfo logging --- controller/channel-test.go | 12 +++--- controller/relay.go | 6 +-- model/channel.go | 6 ++- relay/common/relay_info.go | 76 ++++++++++++++++++++++++++++++++++++ service/pre_consume_quota.go | 2 +- 5 files changed, 92 insertions(+), 10 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 32486a8b..95a4313f 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -144,7 +144,9 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } - err = helper.ModelMappedHelper(c, info, nil) + info.InitChannelMeta(c) + + err = helper.ModelMappedHelper(c, info, request) if err != nil { return testResult{ context: c, @@ -166,10 +168,10 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } - // 创建一个用于日志的 info 副本,移除 ApiKey - logInfo := *info - logInfo.ApiKey = "" - common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) + //// 创建一个用于日志的 info 副本,移除 ApiKey + //logInfo := info + //logInfo.ApiKey = "" + common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString())) priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta()) if err != nil { diff --git a/controller/relay.go b/controller/relay.go index 8b67fd89..57955a18 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -133,13 +133,13 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { return } - preConsumedQuota, newApiErr := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newApiErr != nil { + preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { return } defer func() { - if newApiErr != nil { + if newAPIError != nil { service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) } }() diff --git a/model/channel.go b/model/channel.go index af769f63..7c3ff915 100644 --- a/model/channel.go +++ b/model/channel.go @@ -406,7 +406,11 @@ func (channel *Channel) GetBaseURL() string { if channel.BaseURL == nil { return "" } - return *channel.BaseURL + url := *channel.BaseURL + if url == "" { + url = constant.ChannelBaseURLs[channel.Type] + } + return url } func (channel *Channel) GetModelMapping() string { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 742cd61c..31f9ec6d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -2,6 +2,7 @@ package common import ( "errors" + "fmt" "one-api/common" "one-api/constant" "one-api/dto" @@ -155,6 +156,81 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) { info.ChannelMeta = channelMeta } +func (info *RelayInfo) ToString() string { + if info == nil { + return "RelayInfo" + } + + // Basic info + b := &strings.Builder{} + fmt.Fprintf(b, "RelayInfo{ ") + fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat) + fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode) + fmt.Fprintf(b, "IsStream: %t, ", info.IsStream) + fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground) + fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath) + fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName) + fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens) + fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage) + fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing) + fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount) + fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota) + + // User & token info (mask secrets) + fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ", + info.UserId, info.UserEmail, info.UserGroup, info.UsingGroup, info.UserQuota) + fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited) + + // Time info + latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds() + fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ", + info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs) + + // Audio / realtime + if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage { + fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ", + info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools)) + } + + // Reasoning + if info.ReasoningEffort != "" { + fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort) + } + + // Price data (non-sensitive) + if info.PriceData.UsePrice { + fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting()) + } + + // Channel metadata (mask ApiKey) + if info.ChannelMeta != nil { + cm := info.ChannelMeta + fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ", + cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions) + } + + // Responses usage info (non-sensitive) + if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 { + fmt.Fprintf(b, "ResponsesTools{ ") + first := true + for name, tool := range info.ResponsesUsageInfo.BuiltInTools { + if !first { + fmt.Fprintf(b, ", ") + } + first = false + if tool != nil { + fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount) + } else { + fmt.Fprintf(b, "%s: calls=0", name) + } + } + fmt.Fprintf(b, " }, ") + } + + fmt.Fprintf(b, "}") + return b.String() +} + // 定义支持流式选项的通道类型 var streamSupportedChannels = map[int]bool{ constant.ChannelTypeOpenAI: true, diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 3906414a..ef466d8d 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -37,7 +37,7 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } if userQuota-preConsumedQuota < 0 { - return 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } trustQuota := common.GetTrustQuota() From 44e9b02b3f08d4f69c024d3d872b357252feac26 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 15 Aug 2025 12:41:05 +0800 Subject: [PATCH 06/12] refactor: enhance error handling and masking for model not found scenarios --- common/str.go | 30 +++++++++++++++++++++++++++++- middleware/distributor.go | 4 ++-- middleware/utils.go | 7 ++++++- types/error.go | 13 ++++++++++++- 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/common/str.go b/common/str.go index f5399eab..7d4cdaf0 100644 --- a/common/str.go +++ b/common/str.go @@ -99,12 +99,15 @@ func GetJsonString(data any) string { return string(b) } -// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string +// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string // Example: // http://example.com -> http://***.com // https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=*** // https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/*** // 192.168.1.1 -> ***.***.***.*** +// openai.com -> ***.com +// www.openai.com -> ***.***.com +// api.openai.com -> ***.***.com func MaskSensitiveInfo(str string) string { // Mask URLs urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`) @@ -184,6 +187,31 @@ func MaskSensitiveInfo(str string) string { return result }) + // Mask domain names without protocol (like openai.com, www.openai.com) + domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`) + str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string { + // Skip if it's already been processed as part of a URL + if strings.Contains(str, "://"+domain) { + return domain + } + + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return domain + } + + // Handle different domain patterns + if len(parts) == 2 { + // openai.com -> ***.com + return "***." + parts[1] + } else { + // www.openai.com -> ***.***.com + // api.openai.com -> ***.***.com + lastPart := parts[len(parts)-1] + return "***.***." + lastPart + } + }) + // Mask IP addresses ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`) str = ipPattern.ReplaceAllString(str, "***.***.***.***") diff --git a/middleware/distributor.go b/middleware/distributor.go index 286a4d1f..28b66a3a 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -107,11 +107,11 @@ func Distribute() func(c *gin.Context) { // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) // message = "数据库一致性已被破坏,请联系管理员" //} - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message) + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound)) return } if channel == nil { - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model)) + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound)) return } } diff --git a/middleware/utils.go b/middleware/utils.go index e23bbff7..77d1eb80 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -7,12 +7,17 @@ import ( "one-api/logger" ) -func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { +func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) { + codeStr := "" + if len(code) > 0 { + codeStr = code[0] + } userId := c.GetInt("id") c.JSON(statusCode, gin.H{ "error": gin.H{ "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "type": "new_api_error", + "code": codeStr, }, }) c.Abort() diff --git a/types/error.go b/types/error.go index 2cfeb541..8585461a 100644 --- a/types/error.go +++ b/types/error.go @@ -67,6 +67,7 @@ const ( ErrorCodeBadResponseBody ErrorCode = "bad_response_body" ErrorCodeEmptyResponse ErrorCode = "empty_response" ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error" + ErrorCodeModelNotFound ErrorCode = "model_not_found" // sql error ErrorCodeQueryDataError ErrorCode = "query_data_error" @@ -119,7 +120,17 @@ func (e *NewAPIError) MaskSensitiveError() string { if e.Err == nil { return string(e.errorCode) } - return common.MaskSensitiveInfo(e.Err.Error()) + errStr := e.Err.Error() + if e.StatusCode == http.StatusServiceUnavailable { + if e.errorCode == ErrorCodeModelNotFound { + errStr = "上游分组模型服务不可用,请稍后再试" + } else { + if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") { + errStr = "上游分组模型服务不可用,请稍后再试" + } + } + } + return common.MaskSensitiveInfo(errStr) } func (e *NewAPIError) SetMessage(message string) { From 03fc89da00ef6f5dda778646e3c98eceb59078ce Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 15 Aug 2025 12:50:27 +0800 Subject: [PATCH 07/12] refactor: add email masking function and enhance RelayInfo logging This commit introduces a new function, MaskEmail, to mask user email addresses in logs, preventing PII leakage. Additionally, the RelayInfo logging has been updated to utilize this new masking function, ensuring sensitive information is properly handled. The channel test logic has also been improved to dynamically determine the relay format based on the request path. --- common/str.go | 18 ++++++++++++++++++ controller/channel-test.go | 8 +++++++- relay/common/relay_info.go | 2 +- service/pre_consume_quota.go | 5 +++-- 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/common/str.go b/common/str.go index 7d4cdaf0..a769b8e4 100644 --- a/common/str.go +++ b/common/str.go @@ -99,6 +99,24 @@ func GetJsonString(data any) string { return string(b) } +// MaskEmail masks a user email to prevent PII leakage in logs +// Returns "***masked***" if email is empty, otherwise shows only the domain part +func MaskEmail(email string) string { + if email == "" { + return "***masked***" + } + + // Find the @ symbol + atIndex := strings.Index(email, "@") + if atIndex == -1 { + // No @ symbol found, return masked + return "***masked***" + } + + // Return only the domain part with @ symbol + return "***@" + email[atIndex+1:] +} + // MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string // Example: // http://example.com -> http://***.com diff --git a/controller/channel-test.go b/controller/channel-test.go index 95a4313f..ea37c5bf 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -134,7 +134,13 @@ func testChannel(channel *model.Channel, testModel string) testResult { } request := buildTestRequest(testModel) - info, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, request, nil) + // Determine relay format based on request path + relayFormat := types.RelayFormatOpenAI + if c.Request.URL.Path == "/v1/embeddings" { + relayFormat = types.RelayFormatEmbedding + } + + info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) if err != nil { return testResult{ diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 31f9ec6d..1ebb0581 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -178,7 +178,7 @@ func (info *RelayInfo) ToString() string { // User & token info (mask secrets) fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ", - info.UserId, info.UserEmail, info.UserGroup, info.UsingGroup, info.UserQuota) + info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota) fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited) // Time info diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index ef466d8d..964ab665 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -3,14 +3,15 @@ package service import ( "errors" "fmt" - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" "one-api/types" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { From 5fe1ce89ec5f554f416d9326e91baf6397c50538 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 15 Aug 2025 13:20:36 +0800 Subject: [PATCH 08/12] refactor: improve request type validation and enhance sensitive information masking --- common/str.go | 89 ++++++++++++++-------------- controller/relay.go | 7 ++- logger/logger.go | 9 ++- relay/channel/xunfei/relay-xunfei.go | 6 +- relay/claude_handler.go | 2 +- relay/common/relay_info.go | 38 ++++++------ relay/embedding_handler.go | 2 +- relay/gemini_handler.go | 2 +- relay/helper/common.go | 3 + service/sensitive.go | 22 +------ types/error.go | 9 --- 11 files changed, 87 insertions(+), 102 deletions(-) diff --git a/common/str.go b/common/str.go index a769b8e4..511a0a39 100644 --- a/common/str.go +++ b/common/str.go @@ -117,6 +117,48 @@ func MaskEmail(email string) string { return "***@" + email[atIndex+1:] } +// maskHostTail returns the tail parts of a domain/host that should be preserved. +// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD. +func maskHostTail(parts []string) []string { + if len(parts) < 2 { + return parts + } + lastPart := parts[len(parts)-1] + secondLastPart := parts[len(parts)-2] + if len(lastPart) == 2 && len(secondLastPart) <= 3 { + // Likely country code TLD like co.uk, com.cn + return []string{secondLastPart, lastPart} + } + return []string{lastPart} +} + +// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail. +// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk +func maskHostForURL(host string) string { + parts := strings.Split(host, ".") + if len(parts) < 2 { + return "***" + } + tail := maskHostTail(parts) + return "***." + strings.Join(tail, ".") +} + +// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***. +// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk +func maskHostForPlainDomain(domain string) string { + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return domain + } + tail := maskHostTail(parts) + numStars := len(parts) - len(tail) + if numStars < 1 { + numStars = 1 + } + stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".") + return stars + "." + strings.Join(tail, ".") +} + // MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string // Example: // http://example.com -> http://***.com @@ -140,32 +182,8 @@ func MaskSensitiveInfo(str string) string { return urlStr } - // Split host by dots - parts := strings.Split(host, ".") - if len(parts) < 2 { - // If less than 2 parts, just mask the whole host - return u.Scheme + "://***" + u.Path - } - - // Keep the TLD (Top Level Domain) and mask the rest - var maskedHost string - if len(parts) == 2 { - // example.com -> ***.com - maskedHost = "***." + parts[len(parts)-1] - } else { - // Handle cases like sub.domain.co.uk or api.example.com - // Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.) - lastPart := parts[len(parts)-1] - secondLastPart := parts[len(parts)-2] - - if len(lastPart) == 2 && len(secondLastPart) <= 3 { - // Likely country code TLD like co.uk, com.cn - maskedHost = "***." + secondLastPart + "." + lastPart - } else { - // Regular TLD like .com, .org - maskedHost = "***." + lastPart - } - } + // Mask host with unified logic + maskedHost := maskHostForURL(host) result := u.Scheme + "://" + maskedHost @@ -208,26 +226,11 @@ func MaskSensitiveInfo(str string) string { // Mask domain names without protocol (like openai.com, www.openai.com) domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`) str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string { - // Skip if it's already been processed as part of a URL + // Skip if it's already part of a URL to avoid partial masking if strings.Contains(str, "://"+domain) { return domain } - - parts := strings.Split(domain, ".") - if len(parts) < 2 { - return domain - } - - // Handle different domain patterns - if len(parts) == 2 { - // openai.com -> ***.com - return "***." + parts[1] - } else { - // www.openai.com -> ***.***.com - // api.openai.com -> ***.***.com - lastPart := parts[len(parts)-1] - return "***.***." + lastPart - } + return maskHostForPlainDomain(domain) }) // Mask IP addresses diff --git a/controller/relay.go b/controller/relay.go index 57955a18..b0c995fb 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -113,8 +113,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { meta := request.GetTokenCountMeta() if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveText(meta.CombineText) - if err != nil { + contains, words := service.CheckSensitiveText(meta.CombineText) + if contains { logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected) return @@ -139,7 +139,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } defer func() { - if newAPIError != nil { + // Only return quota if downstream failed and quota was actually pre-consumed + if newAPIError != nil && preConsumedQuota != 0 { service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) } }() diff --git a/logger/logger.go b/logger/logger.go index ca81d624..d59e51cb 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -4,8 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" "io" "log" "one-api/common" @@ -13,6 +11,9 @@ import ( "path/filepath" "sync" "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) const ( @@ -29,6 +30,9 @@ var setupLogLock sync.Mutex var setupLogWorking bool func SetupLogger() { + defer func() { + setupLogWorking = false + }() if *common.LogDir != "" { ok := setupLogLock.TryLock() if !ok { @@ -37,7 +41,6 @@ func SetupLogger() { } defer func() { setupLogLock.Unlock() - setupLogWorking = false }() logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 54ed476f..9d5c190f 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap if err != nil || resp.StatusCode != 101 { return nil, nil, err } + + defer func() { + conn.Close() + }() + data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { @@ -229,7 +234,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap } dataChan <- response if response.Payload.Choices.Status == 2 { - err := conn.Close() if err != nil { common.SysLog("error closing websocket connection: " + err.Error()) } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index ddc424b4..8f846f1c 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -24,7 +24,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ textRequest, ok := info.Request.(*dto.ClaudeRequest) if !ok { - common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) + common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request)) } err := helper.ModelMappedHelper(c, info, textRequest) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 1ebb0581..51142ff8 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -87,26 +87,24 @@ type RelayInfo struct { UsePrice bool RelayMode int OriginModelName string - //RecodeModelName string - RequestURLPath string - PromptTokens int - //SupportStreamOptions bool - ShouldIncludeUsage bool - DisablePing bool // 是否禁止向下游发送自定义 Ping - ClientWs *websocket.Conn - TargetWs *websocket.Conn - InputAudioFormat string - OutputAudioFormat string - RealtimeTools []dto.RealTimeTool - IsFirstRequest bool - AudioUsage bool - ReasoningEffort string - UserSetting dto.UserSetting - UserEmail string - UserQuota int - RelayFormat types.RelayFormat - SendResponseCount int - FinalPreConsumedQuota int // 最终预消耗的配额 + RequestURLPath string + PromptTokens int + ShouldIncludeUsage bool + DisablePing bool // 是否禁止向下游发送自定义 Ping + ClientWs *websocket.Conn + TargetWs *websocket.Conn + InputAudioFormat string + OutputAudioFormat string + RealtimeTools []dto.RealTimeTool + IsFirstRequest bool + AudioUsage bool + ReasoningEffort string + UserSetting dto.UserSetting + UserEmail string + UserQuota int + RelayFormat types.RelayFormat + SendResponseCount int + FinalPreConsumedQuota int // 最终预消耗的配额 PriceData types.PriceData diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index f7906cf9..99f0d817 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -21,7 +21,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest) if !ok { - common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) + common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request)) } err := helper.ModelMappedHelper(c, info, embeddingRequest) diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 3ebe0884..d50fff42 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -55,7 +55,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ request, ok := info.Request.(*dto.GeminiChatRequest) if !ok { - common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request)) + common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request)) } // model mapped 模型映射 diff --git a/relay/helper/common.go b/relay/helper/common.go index 5075314d..4b2c51eb 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -122,6 +122,9 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { } func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) { + if ws == nil { + return + } errorObj := &dto.RealtimeEvent{ Type: "error", EventId: GetLocalRealtimeID(c), diff --git a/service/sensitive.go b/service/sensitive.go index b3e3c4d6..25cfd46f 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -2,7 +2,6 @@ package service import ( "errors" - "fmt" "one-api/dto" "one-api/setting" "strings" @@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) { return nil, nil } -func CheckSensitiveText(text string) ([]string, error) { - if ok, words := SensitiveWordContains(text); ok { - return words, errors.New("sensitive words detected") - } - return nil, nil -} - -func CheckSensitiveInput(input any) ([]string, error) { - switch v := input.(type) { - case string: - return CheckSensitiveText(v) - case []string: - var builder strings.Builder - for _, s := range v { - builder.WriteString(s) - } - return CheckSensitiveText(builder.String()) - } - return CheckSensitiveText(fmt.Sprintf("%v", input)) +func CheckSensitiveText(text string) (bool, []string) { + return SensitiveWordContains(text) } // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 diff --git a/types/error.go b/types/error.go index 8585461a..07486c27 100644 --- a/types/error.go +++ b/types/error.go @@ -121,15 +121,6 @@ func (e *NewAPIError) MaskSensitiveError() string { return string(e.errorCode) } errStr := e.Err.Error() - if e.StatusCode == http.StatusServiceUnavailable { - if e.errorCode == ErrorCodeModelNotFound { - errStr = "上游分组模型服务不可用,请稍后再试" - } else { - if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") { - errStr = "上游分组模型服务不可用,请稍后再试" - } - } - } return common.MaskSensitiveInfo(errStr) } From 2f25e44e60753318faa60d10d8e897971eddac00 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 15 Aug 2025 13:28:03 +0800 Subject: [PATCH 09/12] refactor: update token type handling and improve token counting logic --- common/{logger.go => sys_log.go} | 0 dto/claude.go | 2 +- service/token_counter.go | 10 ++++++++-- 3 files changed, 9 insertions(+), 3 deletions(-) rename common/{logger.go => sys_log.go} (100%) diff --git a/common/logger.go b/common/sys_log.go similarity index 100% rename from common/logger.go rename to common/sys_log.go diff --git a/dto/claude.go b/dto/claude.go index 2b3adf19..48bef659 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -204,7 +204,7 @@ type ClaudeRequest struct { func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { var tokenCountMeta = types.TokenCountMeta{ - TokenType: types.TokenTypeTextNumber, + TokenType: types.TokenTypeTokenizer, MaxTokens: int(c.MaxTokens), } diff --git a/service/token_counter.go b/service/token_counter.go index 43a508c1..314fa593 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -251,9 +251,15 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco if info.RelayFormat == types.RelayFormatOpenAIRealtime { return 0, nil } - + model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) - tkm := CountTextToken(meta.CombineText, model) + tkm := 0 + + if meta.TokenType == types.TokenTypeTextNumber { + tkm += utf8.RuneCountInString(meta.CombineText) + } else { + tkm += CountTextToken(meta.CombineText, model) + } if info.RelayFormat == types.RelayFormatOpenAI { tkm += meta.ToolsCount * 8 From 067be3727e95838f6ce21e30fa1baffbcb41c080 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 15 Aug 2025 13:46:46 +0800 Subject: [PATCH 10/12] refactor: simplify domain masking logic by removing URL check --- common/str.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/common/str.go b/common/str.go index 511a0a39..6debce28 100644 --- a/common/str.go +++ b/common/str.go @@ -226,10 +226,6 @@ func MaskSensitiveInfo(str string) string { // Mask domain names without protocol (like openai.com, www.openai.com) domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`) str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string { - // Skip if it's already part of a URL to avoid partial masking - if strings.Contains(str, "://"+domain) { - return domain - } return maskHostForPlainDomain(domain) }) From cc4f73dc7e125ac6d58949b5c657885c73a3c969 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 15 Aug 2025 14:08:15 +0800 Subject: [PATCH 11/12] refactor: enhance logging messages for user quota handling in pre-consume logic --- service/pre_consume_quota.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 964ab665..08e3f68f 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -16,6 +16,7 @@ import ( func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { if preConsumedQuota != 0 { + logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota))) gopool.Go(func() { relayInfoCopy := *relayInfo @@ -52,13 +53,13 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if tokenQuota > trustQuota { // 令牌额度充足,信任令牌 preConsumedQuota = 0 - logger.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) + logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) } } else { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 - logger.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota))) + logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId)) } } From 1d4850e47a8961388da84a840c1b9a2270b21e2e Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 15 Aug 2025 14:15:03 +0800 Subject: [PATCH 12/12] refactor: improve logging for channel operations with detailed context --- model/channel.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/model/channel.go b/model/channel.go index 7c3ff915..a9a23481 100644 --- a/model/channel.go +++ b/model/channel.go @@ -209,7 +209,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { if channel.OtherInfo != "" { err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { - common.SysLog("failed to unmarshal other info: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err)) } } return otherInfo @@ -218,7 +218,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { otherInfoBytes, err := json.Marshal(otherInfo) if err != nil { - common.SysLog("failed to marshal other info: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err)) return } channel.OtherInfo = string(otherInfoBytes) @@ -492,7 +492,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { ResponseTime: int(responseTime), }).Error if err != nil { - common.SysLog("failed to update response time: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err)) } } @@ -502,7 +502,7 @@ func (channel *Channel) UpdateBalance(balance float64) { Balance: balance, }).Error if err != nil { - common.SysLog("failed to update balance: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err)) } } @@ -618,7 +618,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if shouldUpdateAbilities { err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) if err != nil { - common.SysLog("failed to update ability status: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err)) } } }() @@ -646,7 +646,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri } err = channel.Save() if err != nil { - common.SysLog("failed to update channel status: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err)) return false } } @@ -708,7 +708,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models * for _, channel := range channels { err = channel.UpdateAbilities(nil) if err != nil { - common.SysLog("failed to update abilities: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err)) } } } @@ -732,7 +732,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - common.SysLog("failed to update channel used quota: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err)) } } @@ -825,7 +825,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { if channel.Setting != nil && *channel.Setting != "" { err := common.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { - common.SysLog("failed to unmarshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err)) channel.Setting = nil // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -836,7 +836,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { func (channel *Channel) SetSetting(setting dto.ChannelSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - common.SysLog("failed to marshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err)) return } channel.Setting = common.GetPointer[string](string(settingBytes)) @@ -847,7 +847,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { if channel.OtherSettings != "" { err := common.UnmarshalJsonStr(channel.OtherSettings, &setting) if err != nil { - common.SysLog("failed to unmarshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err)) channel.OtherSettings = "{}" // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -858,7 +858,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - common.SysLog("failed to marshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err)) return } channel.OtherSettings = string(settingBytes) @@ -869,7 +869,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} { if channel.ParamOverride != nil && *channel.ParamOverride != "" { err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) if err != nil { - common.SysLog("failed to unmarshal param override: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err)) } } return paramOverride