diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index ef5d1935..fcfcb0c3 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -5,7 +5,7 @@ import ( _ "embed" "fmt" "github.com/go-redis/redis/v8" - "one-api/common" + "one-api/logger" "sync" ) @@ -27,7 +27,7 @@ func New(ctx context.Context, r *redis.Client) *RedisLimiter { // 预加载脚本 limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result() if err != nil { - common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) + logger.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) } instance = &RedisLimiter{ client: r, diff --git a/common/logger.go b/common/logger.go index 0f6dc3c3..478015f0 100644 --- a/common/logger.go +++ b/common/logger.go @@ -1,52 +1,12 @@ package common import ( - "context" - "encoding/json" "fmt" - "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" - "io" - "log" "os" - "path/filepath" - "sync" "time" ) -const ( - loggerINFO = "INFO" - loggerWarn = "WARN" - loggerError = "ERR" -) - -const maxLogCount = 1000000 - -var logCount int -var setupLogLock sync.Mutex -var setupLogWorking bool - -func SetupLogger() { - if *LogDir != "" { - ok := setupLogLock.TryLock() - if !ok { - log.Println("setup log is already working") - return - } - defer func() { - setupLogLock.Unlock() - setupLogWorking = false - }() - logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) - fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Fatal("failed to open log file") - } - gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) - gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) - } -} - func SysLog(s string) { t := time.Now() _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) @@ -57,67 +17,8 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 
15:04:05"), s) } -func LogInfo(ctx context.Context, msg string) { - logHelper(ctx, loggerINFO, msg) -} - -func LogWarn(ctx context.Context, msg string) { - logHelper(ctx, loggerWarn, msg) -} - -func LogError(ctx context.Context, msg string) { - logHelper(ctx, loggerError, msg) -} - -func logHelper(ctx context.Context, level string, msg string) { - writer := gin.DefaultErrorWriter - if level == loggerINFO { - writer = gin.DefaultWriter - } - id := ctx.Value(RequestIdKey) - if id == nil { - id = "SYSTEM" - } - now := time.Now() - _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) - logCount++ // we don't need accurate count, so no lock here - if logCount > maxLogCount && !setupLogWorking { - logCount = 0 - setupLogWorking = true - gopool.Go(func() { - SetupLogger() - }) - } -} - func FatalLog(v ...any) { t := time.Now() _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) os.Exit(1) } - -func LogQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) - } else { - return fmt.Sprintf("%d 点额度", quota) - } -} - -func FormatQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit) - } else { - return fmt.Sprintf("%d", quota) - } -} - -// LogJson 仅供测试使用 only for test -func LogJson(ctx context.Context, msg string, obj any) { - jsonStr, err := json.Marshal(obj) - if err != nil { - LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) - return - } - LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr))) -} diff --git a/constant/context_key.go b/constant/context_key.go index b82b19e7..569a0373 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -3,6 +3,8 @@ package constant type ContextKey string const ( + ContextKeyPromptTokens ContextKey = "prompt_tokens" + ContextKeyOriginalModel ContextKey = "original_model" 
ContextKeyRequestStartTime ContextKey = "request_start_time" diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 5152e060..bbf0f97a 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/constant" + "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -485,8 +486,8 @@ func UpdateAllChannelsBalance(c *gin.Context) { func AutomaticallyUpdateChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("updating all channels") + logger.SysLog("updating all channels") _ = updateAllChannelsBalance() - common.SysLog("channels update done") + logger.SysLog("channels update done") } } diff --git a/controller/channel-test.go b/controller/channel-test.go index 026a863b..ec2e6226 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -13,6 +13,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/relay" @@ -159,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { // 创建一个用于日志的 info 副本,移除 ApiKey logInfo := *info logInfo.ApiKey = "" - common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) + logger.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens())) if err != nil { @@ -279,7 +280,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { Group: info.UsingGroup, Other: other, }) - common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) + logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return testResult{ context: c, localErr: nil, @@ -461,13 +462,13 @@ func TestAllChannels(c *gin.Context) { func 
AutomaticallyTestChannels(frequency int) { if frequency <= 0 { - common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") + logger.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") return } for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("testing all channels") + logger.SysLog("testing all channels") _ = testAllChannels(false) - common.SysLog("channel test finished") + logger.SysLog("channel test finished") } } diff --git a/controller/console_migrate.go b/controller/console_migrate.go index d25f199b..d21f5e21 100644 --- a/controller/console_migrate.go +++ b/controller/console_migrate.go @@ -3,101 +3,101 @@ package controller import ( - "encoding/json" - "net/http" - "one-api/common" - "one-api/model" - "github.com/gin-gonic/gin" + "encoding/json" + "github.com/gin-gonic/gin" + "net/http" + "one-api/logger" + "one-api/model" ) // MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.* func MigrateConsoleSetting(c *gin.Context) { - // 读取全部 option - opts, err := model.AllOption() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) - return - } - // 建立 map - valMap := map[string]string{} - for _, o := range opts { - valMap[o.Key] = o.Value - } + // 读取全部 option + opts, err := model.AllOption() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) + return + } + // 建立 map + valMap := map[string]string{} + for _, o := range opts { + valMap[o.Key] = o.Value + } - // 处理 APIInfo - if v := valMap["ApiInfo"]; v != "" { - var arr []map[string]interface{} - if err := json.Unmarshal([]byte(v), &arr); err == nil { - if len(arr) > 50 { - arr = arr[:50] - } - bytes, _ := json.Marshal(arr) - model.UpdateOption("console_setting.api_info", string(bytes)) - } - model.UpdateOption("ApiInfo", "") - } - // Announcements 直接搬 - if v := valMap["Announcements"]; v != "" { - 
model.UpdateOption("console_setting.announcements", v) - model.UpdateOption("Announcements", "") - } - // FAQ 转换 - if v := valMap["FAQ"]; v != "" { - var arr []map[string]interface{} - if err := json.Unmarshal([]byte(v), &arr); err == nil { - out := []map[string]interface{}{} - for _, item := range arr { - q, _ := item["question"].(string) - if q == "" { - q, _ = item["title"].(string) - } - a, _ := item["answer"].(string) - if a == "" { - a, _ = item["content"].(string) - } - if q != "" && a != "" { - out = append(out, map[string]interface{}{"question": q, "answer": a}) - } - } - if len(out) > 50 { - out = out[:50] - } - bytes, _ := json.Marshal(out) - model.UpdateOption("console_setting.faq", string(bytes)) - } - model.UpdateOption("FAQ", "") - } - // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups) - url := valMap["UptimeKumaUrl"] - slug := valMap["UptimeKumaSlug"] - if url != "" && slug != "" { - // 仅当同时存在 URL 与 Slug 时才进行迁移 - groups := []map[string]interface{}{ - { - "id": 1, - "categoryName": "old", - "url": url, - "slug": slug, - "description": "", - }, - } - bytes, _ := json.Marshal(groups) - model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes)) - } - // 清空旧键内容 - if url != "" { - model.UpdateOption("UptimeKumaUrl", "") - } - if slug != "" { - model.UpdateOption("UptimeKumaSlug", "") - } + // 处理 APIInfo + if v := valMap["ApiInfo"]; v != "" { + var arr []map[string]interface{} + if err := json.Unmarshal([]byte(v), &arr); err == nil { + if len(arr) > 50 { + arr = arr[:50] + } + bytes, _ := json.Marshal(arr) + model.UpdateOption("console_setting.api_info", string(bytes)) + } + model.UpdateOption("ApiInfo", "") + } + // Announcements 直接搬 + if v := valMap["Announcements"]; v != "" { + model.UpdateOption("console_setting.announcements", v) + model.UpdateOption("Announcements", "") + } + // FAQ 转换 + if v := valMap["FAQ"]; v != "" { + var arr []map[string]interface{} + if err := json.Unmarshal([]byte(v), &arr); err == nil { + out := 
[]map[string]interface{}{} + for _, item := range arr { + q, _ := item["question"].(string) + if q == "" { + q, _ = item["title"].(string) + } + a, _ := item["answer"].(string) + if a == "" { + a, _ = item["content"].(string) + } + if q != "" && a != "" { + out = append(out, map[string]interface{}{"question": q, "answer": a}) + } + } + if len(out) > 50 { + out = out[:50] + } + bytes, _ := json.Marshal(out) + model.UpdateOption("console_setting.faq", string(bytes)) + } + model.UpdateOption("FAQ", "") + } + // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups) + url := valMap["UptimeKumaUrl"] + slug := valMap["UptimeKumaSlug"] + if url != "" && slug != "" { + // 仅当同时存在 URL 与 Slug 时才进行迁移 + groups := []map[string]interface{}{ + { + "id": 1, + "categoryName": "old", + "url": url, + "slug": slug, + "description": "", + }, + } + bytes, _ := json.Marshal(groups) + model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes)) + } + // 清空旧键内容 + if url != "" { + model.UpdateOption("UptimeKumaUrl", "") + } + if slug != "" { + model.UpdateOption("UptimeKumaSlug", "") + } - // 删除旧键记录 - oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"} - model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{}) + // 删除旧键记录 + oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"} + model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{}) - // 重新加载 OptionMap - model.InitOptionMap() - common.SysLog("console setting migrated") - c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) -} \ No newline at end of file + // 重新加载 OptionMap + model.InitOptionMap() + logger.SysLog("console setting migrated") + c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) +} diff --git a/controller/github.go b/controller/github.go index 881d6dc1..0715a8fe 100644 --- a/controller/github.go +++ b/controller/github.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "one-api/common" + 
"one-api/logger" "one-api/model" "strconv" "time" @@ -47,7 +48,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } res, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res.Body.Close() @@ -63,7 +64,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) res2, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res2.Body.Close() diff --git a/controller/midjourney.go b/controller/midjourney.go index 30a5a09a..a67d39c2 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -9,6 +9,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -28,7 +29,7 @@ func UpdateMidjourneyTaskBulk() { continue } - common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) + logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) taskChannelM := make(map[int][]string) taskM := make(map[string]*model.Midjourney) nullTaskIds := make([]int, 0) @@ -47,9 +48,9 @@ func UpdateMidjourneyTaskBulk() { "progress": "100%", }) if err != nil { - common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) } else { - common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) + logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { @@ -57,20 +58,20 @@ func UpdateMidjourneyTaskBulk() { } for channelId, taskIds := range taskChannelM { - common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { 
continue } midjourneyChannel, err := model.CacheGetChannel(channelId) if err != nil { - common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) + logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) err := model.MjBulkUpdate(taskIds, map[string]any{ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), "status": "FAILURE", "progress": "100%", }) if err != nil { - common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) + logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) } continue } @@ -81,7 +82,7 @@ func UpdateMidjourneyTaskBulk() { }) req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) continue } // 设置超时时间 @@ -93,22 +94,22 @@ func UpdateMidjourneyTaskBulk() { req.Header.Set("mj-api-secret", midjourneyChannel.Key) resp, err := service.GetHttpClient().Do(req) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue } if resp.StatusCode != http.StatusOK { - common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) continue } responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) continue } var responseItems []dto.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) continue } resp.Body.Close() @@ -147,12 +148,12 @@ func UpdateMidjourneyTaskBulk() { } // 
映射 VideoUrl task.VideoUrl = responseItem.VideoUrl - + // 映射 VideoUrls - 将数组序列化为 JSON 字符串 if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 { videoUrlsStr, err := json.Marshal(responseItem.VideoUrls) if err != nil { - common.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err)) + logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err)) task.VideoUrls = "[]" // 失败时设置为空数组 } else { task.VideoUrls = string(videoUrlsStr) @@ -160,10 +161,10 @@ func UpdateMidjourneyTaskBulk() { } else { task.VideoUrls = "" // 空值时清空字段 } - + shouldReturnQuota := false if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { - common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) + logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" if task.Quota != 0 { shouldReturnQuota = true @@ -171,14 +172,14 @@ func UpdateMidjourneyTaskBulk() { } err = task.Update() if err != nil { - common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) + logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) } else { if shouldReturnQuota { err = model.IncreaseUserQuota(task.UserId, task.Quota, false) if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota)) + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } diff --git a/controller/oidc.go b/controller/oidc.go index df8ea1c4..1e3435a8 100644 --- a/controller/oidc.go +++ b/controller/oidc.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/logger" "one-api/model" "one-api/setting" "one-api/setting/system_setting" @@ -58,7 +59,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } res, err := client.Do(req) if 
err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") } defer res.Body.Close() @@ -69,7 +70,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } if oidcResponse.AccessToken == "" { - common.SysError("OIDC 获取 Token 失败,请检查设置!") + logger.SysError("OIDC 获取 Token 失败,请检查设置!") return nil, errors.New("OIDC 获取 Token 失败,请检查设置!") } @@ -80,12 +81,12 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) res2, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") } defer res2.Body.Close() if res2.StatusCode != http.StatusOK { - common.SysError("OIDC 获取用户信息失败!请检查设置!") + logger.SysError("OIDC 获取用户信息失败!请检查设置!") return nil, errors.New("OIDC 获取用户信息失败!请检查设置!") } @@ -95,7 +96,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { return nil, err } if oidcUser.OpenID == "" || oidcUser.Email == "" { - common.SysError("OIDC 获取用户信息为空!请检查设置!") + logger.SysError("OIDC 获取用户信息为空!请检查设置!") return nil, errors.New("OIDC 获取用户信息为空!请检查设置!") } return &oidcUser, nil diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index 0453870d..6fba0aac 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -1,474 +1,474 @@ package controller import ( - "context" - "encoding/json" - "fmt" - "net/http" - "strings" - "sync" - "time" + "context" + "encoding/json" + "fmt" + "net/http" + "one-api/logger" + "strings" + "sync" + "time" - "one-api/common" - "one-api/dto" - "one-api/model" - "one-api/setting/ratio_setting" + "one-api/dto" + "one-api/model" + "one-api/setting/ratio_setting" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) const ( - defaultTimeoutSeconds = 10 - defaultEndpoint = "/api/ratio_config" - maxConcurrentFetches = 8 + defaultTimeoutSeconds = 10 + defaultEndpoint = "/api/ratio_config" + 
maxConcurrentFetches = 8 ) var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} type upstreamResult struct { - Name string `json:"name"` - Data map[string]any `json:"data,omitempty"` - Err string `json:"err,omitempty"` + Name string `json:"name"` + Data map[string]any `json:"data,omitempty"` + Err string `json:"err,omitempty"` } func FetchUpstreamRatios(c *gin.Context) { - var req dto.UpstreamRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) - return - } + var req dto.UpstreamRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } - if req.Timeout <= 0 { - req.Timeout = defaultTimeoutSeconds - } + if req.Timeout <= 0 { + req.Timeout = defaultTimeoutSeconds + } - var upstreams []dto.UpstreamDTO + var upstreams []dto.UpstreamDTO - if len(req.Upstreams) > 0 { - for _, u := range req.Upstreams { - if strings.HasPrefix(u.BaseURL, "http") { - if u.Endpoint == "" { - u.Endpoint = defaultEndpoint - } - u.BaseURL = strings.TrimRight(u.BaseURL, "/") - upstreams = append(upstreams, u) - } - } - } else if len(req.ChannelIDs) > 0 { - intIds := make([]int, 0, len(req.ChannelIDs)) - for _, id64 := range req.ChannelIDs { - intIds = append(intIds, int(id64)) - } - dbChannels, err := model.GetChannelsByIds(intIds) - if err != nil { - common.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) - c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) - return - } - for _, ch := range dbChannels { - if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { - upstreams = append(upstreams, dto.UpstreamDTO{ - ID: ch.Id, - Name: ch.Name, - BaseURL: strings.TrimRight(base, "/"), - Endpoint: "", - }) - } - } - } + if len(req.Upstreams) > 0 { + for _, u := range req.Upstreams { + if strings.HasPrefix(u.BaseURL, "http") { + 
if u.Endpoint == "" { + u.Endpoint = defaultEndpoint + } + u.BaseURL = strings.TrimRight(u.BaseURL, "/") + upstreams = append(upstreams, u) + } + } + } else if len(req.ChannelIDs) > 0 { + intIds := make([]int, 0, len(req.ChannelIDs)) + for _, id64 := range req.ChannelIDs { + intIds = append(intIds, int(id64)) + } + dbChannels, err := model.GetChannelsByIds(intIds) + if err != nil { + logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) + return + } + for _, ch := range dbChannels { + if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { + upstreams = append(upstreams, dto.UpstreamDTO{ + ID: ch.Id, + Name: ch.Name, + BaseURL: strings.TrimRight(base, "/"), + Endpoint: "", + }) + } + } + } - if len(upstreams) == 0 { - c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) - return - } + if len(upstreams) == 0 { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) + return + } - var wg sync.WaitGroup - ch := make(chan upstreamResult, len(upstreams)) + var wg sync.WaitGroup + ch := make(chan upstreamResult, len(upstreams)) - sem := make(chan struct{}, maxConcurrentFetches) + sem := make(chan struct{}, maxConcurrentFetches) - client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} + client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} - for _, chn := range upstreams { - wg.Add(1) - go func(chItem dto.UpstreamDTO) { - defer wg.Done() + for _, chn := range upstreams { + wg.Add(1) + go func(chItem dto.UpstreamDTO) { + defer wg.Done() - sem <- struct{}{} - defer func() { <-sem }() + sem <- struct{}{} + defer func() { <-sem }() - endpoint := chItem.Endpoint - if endpoint 
== "" { - endpoint = defaultEndpoint - } else if !strings.HasPrefix(endpoint, "/") { - endpoint = "/" + endpoint - } - fullURL := chItem.BaseURL + endpoint + endpoint := chItem.Endpoint + if endpoint == "" { + endpoint = defaultEndpoint + } else if !strings.HasPrefix(endpoint, "/") { + endpoint = "/" + endpoint + } + fullURL := chItem.BaseURL + endpoint - uniqueName := chItem.Name - if chItem.ID != 0 { - uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) - } + uniqueName := chItem.Name + if chItem.ID != 0 { + uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) + } - ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) + defer cancel() - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) - if err != nil { - common.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } - resp, err := client.Do(httpReq) - if err != nil { - common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) - ch <- upstreamResult{Name: uniqueName, Err: resp.Status} - return - } - // 兼容两种上游接口格式: - // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price - // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 - var body struct { - Success bool `json:"success"` - Data 
json.RawMessage `json:"data"` - Message string `json:"message"` - } + resp, err := client.Do(httpReq) + if err != nil { + logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) + ch <- upstreamResult{Name: uniqueName, Err: resp.Status} + return + } + // 兼容两种上游接口格式: + // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price + // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 + var body struct { + Success bool `json:"success"` + Data json.RawMessage `json:"data"` + Message string `json:"message"` + } - if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { - common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } - if !body.Success { - ch <- upstreamResult{Name: uniqueName, Err: body.Message} - return - } + if !body.Success { + ch <- upstreamResult{Name: uniqueName, Err: body.Message} + return + } - // 尝试按 type1 解析 - var type1Data map[string]any - if err := json.Unmarshal(body.Data, &type1Data); err == nil { - // 如果包含至少一个 ratioTypes 字段,则认为是 type1 - isType1 := false - for _, rt := range ratioTypes { - if _, ok := type1Data[rt]; ok { - isType1 = true - break - } - } - if isType1 { - ch <- upstreamResult{Name: uniqueName, Data: type1Data} - return - } - } + // 尝试按 type1 解析 + var type1Data map[string]any + if err := json.Unmarshal(body.Data, &type1Data); err == nil { + // 如果包含至少一个 ratioTypes 字段,则认为是 type1 + 
isType1 := false + for _, rt := range ratioTypes { + if _, ok := type1Data[rt]; ok { + isType1 = true + break + } + } + if isType1 { + ch <- upstreamResult{Name: uniqueName, Data: type1Data} + return + } + } - // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析 - var pricingItems []struct { - ModelName string `json:"model_name"` - QuotaType int `json:"quota_type"` - ModelRatio float64 `json:"model_ratio"` - ModelPrice float64 `json:"model_price"` - CompletionRatio float64 `json:"completion_ratio"` - } - if err := json.Unmarshal(body.Data, &pricingItems); err != nil { - common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} - return - } + // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析 + var pricingItems []struct { + ModelName string `json:"model_name"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + CompletionRatio float64 `json:"completion_ratio"` + } + if err := json.Unmarshal(body.Data, &pricingItems); err != nil { + logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} + return + } - modelRatioMap := make(map[string]float64) - completionRatioMap := make(map[string]float64) - modelPriceMap := make(map[string]float64) + modelRatioMap := make(map[string]float64) + completionRatioMap := make(map[string]float64) + modelPriceMap := make(map[string]float64) - for _, item := range pricingItems { - if item.QuotaType == 1 { - modelPriceMap[item.ModelName] = item.ModelPrice - } else { - modelRatioMap[item.ModelName] = item.ModelRatio - // completionRatio 可能为 0,此时也直接赋值,保持与上游一致 - completionRatioMap[item.ModelName] = item.CompletionRatio - } - } + for _, item := range pricingItems { + if item.QuotaType == 1 { + modelPriceMap[item.ModelName] = item.ModelPrice + } else { + modelRatioMap[item.ModelName] = 
item.ModelRatio + // completionRatio 可能为 0,此时也直接赋值,保持与上游一致 + completionRatioMap[item.ModelName] = item.CompletionRatio + } + } - converted := make(map[string]any) + converted := make(map[string]any) - if len(modelRatioMap) > 0 { - ratioAny := make(map[string]any, len(modelRatioMap)) - for k, v := range modelRatioMap { - ratioAny[k] = v - } - converted["model_ratio"] = ratioAny - } + if len(modelRatioMap) > 0 { + ratioAny := make(map[string]any, len(modelRatioMap)) + for k, v := range modelRatioMap { + ratioAny[k] = v + } + converted["model_ratio"] = ratioAny + } - if len(completionRatioMap) > 0 { - compAny := make(map[string]any, len(completionRatioMap)) - for k, v := range completionRatioMap { - compAny[k] = v - } - converted["completion_ratio"] = compAny - } + if len(completionRatioMap) > 0 { + compAny := make(map[string]any, len(completionRatioMap)) + for k, v := range completionRatioMap { + compAny[k] = v + } + converted["completion_ratio"] = compAny + } - if len(modelPriceMap) > 0 { - priceAny := make(map[string]any, len(modelPriceMap)) - for k, v := range modelPriceMap { - priceAny[k] = v - } - converted["model_price"] = priceAny - } + if len(modelPriceMap) > 0 { + priceAny := make(map[string]any, len(modelPriceMap)) + for k, v := range modelPriceMap { + priceAny[k] = v + } + converted["model_price"] = priceAny + } - ch <- upstreamResult{Name: uniqueName, Data: converted} - }(chn) - } + ch <- upstreamResult{Name: uniqueName, Data: converted} + }(chn) + } - wg.Wait() - close(ch) + wg.Wait() + close(ch) - localData := ratio_setting.GetExposedData() + localData := ratio_setting.GetExposedData() - var testResults []dto.TestResult - var successfulChannels []struct { - name string - data map[string]any - } + var testResults []dto.TestResult + var successfulChannels []struct { + name string + data map[string]any + } - for r := range ch { - if r.Err != "" { - testResults = append(testResults, dto.TestResult{ - Name: r.Name, - Status: "error", - Error: r.Err, - }) - } 
else { - testResults = append(testResults, dto.TestResult{ - Name: r.Name, - Status: "success", - }) - successfulChannels = append(successfulChannels, struct { - name string - data map[string]any - }{name: r.Name, data: r.Data}) - } - } + for r := range ch { + if r.Err != "" { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "error", + Error: r.Err, + }) + } else { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "success", + }) + successfulChannels = append(successfulChannels, struct { + name string + data map[string]any + }{name: r.Name, data: r.Data}) + } + } - differences := buildDifferences(localData, successfulChannels) + differences := buildDifferences(localData, successfulChannels) - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": gin.H{ - "differences": differences, - "test_results": testResults, - }, - }) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "differences": differences, + "test_results": testResults, + }, + }) } func buildDifferences(localData map[string]any, successfulChannels []struct { - name string - data map[string]any + name string + data map[string]any }) map[string]map[string]dto.DifferenceItem { - differences := make(map[string]map[string]dto.DifferenceItem) + differences := make(map[string]map[string]dto.DifferenceItem) - allModels := make(map[string]struct{}) - - for _, ratioType := range ratioTypes { - if localRatioAny, ok := localData[ratioType]; ok { - if localRatio, ok := localRatioAny.(map[string]float64); ok { - for modelName := range localRatio { - allModels[modelName] = struct{}{} - } - } - } - } - - for _, channel := range successfulChannels { - for _, ratioType := range ratioTypes { - if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { - for modelName := range upstreamRatio { - allModels[modelName] = struct{}{} - } - } - } - } + allModels := make(map[string]struct{}) - confidenceMap := make(map[string]map[string]bool) 
- - // 预处理阶段:检查pricing接口的可信度 - for _, channel := range successfulChannels { - confidenceMap[channel.name] = make(map[string]bool) - - modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) - completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) - - if hasModelRatio && hasCompletionRatio { - // 遍历所有模型,检查是否满足不可信条件 - for modelName := range allModels { - // 默认为可信 - confidenceMap[channel.name][modelName] = true - - // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1 - if modelRatioVal, ok := modelRatios[modelName]; ok { - if completionRatioVal, ok := completionRatios[modelName]; ok { - // 转换为float64进行比较 - if modelRatioFloat, ok := modelRatioVal.(float64); ok { - if completionRatioFloat, ok := completionRatioVal.(float64); ok { - if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { - confidenceMap[channel.name][modelName] = false - } - } - } - } - } - } - } else { - // 如果不是从pricing接口获取的数据,则全部标记为可信 - for modelName := range allModels { - confidenceMap[channel.name][modelName] = true - } - } - } + for _, ratioType := range ratioTypes { + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + for modelName := range localRatio { + allModels[modelName] = struct{}{} + } + } + } + } - for modelName := range allModels { - for _, ratioType := range ratioTypes { - var localValue interface{} = nil - if localRatioAny, ok := localData[ratioType]; ok { - if localRatio, ok := localRatioAny.(map[string]float64); ok { - if val, exists := localRatio[modelName]; exists { - localValue = val - } - } - } + for _, channel := range successfulChannels { + for _, ratioType := range ratioTypes { + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + for modelName := range upstreamRatio { + allModels[modelName] = struct{}{} + } + } + } + } - upstreamValues := make(map[string]interface{}) - confidenceValues := make(map[string]bool) - hasUpstreamValue := false 
- hasDifference := false + confidenceMap := make(map[string]map[string]bool) - for _, channel := range successfulChannels { - var upstreamValue interface{} = nil - - if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { - if val, exists := upstreamRatio[modelName]; exists { - upstreamValue = val - hasUpstreamValue = true - - if localValue != nil && localValue != val { - hasDifference = true - } else if localValue == val { - upstreamValue = "same" - } - } - } - if upstreamValue == nil && localValue == nil { - upstreamValue = "same" - } - - if localValue == nil && upstreamValue != nil && upstreamValue != "same" { - hasDifference = true - } - - upstreamValues[channel.name] = upstreamValue - - confidenceValues[channel.name] = confidenceMap[channel.name][modelName] - } + // 预处理阶段:检查pricing接口的可信度 + for _, channel := range successfulChannels { + confidenceMap[channel.name] = make(map[string]bool) - shouldInclude := false - - if localValue != nil { - if hasDifference { - shouldInclude = true - } - } else { - if hasUpstreamValue { - shouldInclude = true - } - } + modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) + completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) - if shouldInclude { - if differences[modelName] == nil { - differences[modelName] = make(map[string]dto.DifferenceItem) - } - differences[modelName][ratioType] = dto.DifferenceItem{ - Current: localValue, - Upstreams: upstreamValues, - Confidence: confidenceValues, - } - } - } - } + if hasModelRatio && hasCompletionRatio { + // 遍历所有模型,检查是否满足不可信条件 + for modelName := range allModels { + // 默认为可信 + confidenceMap[channel.name][modelName] = true - channelHasDiff := make(map[string]bool) - for _, ratioMap := range differences { - for _, item := range ratioMap { - for chName, val := range item.Upstreams { - if val != nil && val != "same" { - channelHasDiff[chName] = true - } - } - } - } + // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1 + 
if modelRatioVal, ok := modelRatios[modelName]; ok { + if completionRatioVal, ok := completionRatios[modelName]; ok { + // 转换为float64进行比较 + if modelRatioFloat, ok := modelRatioVal.(float64); ok { + if completionRatioFloat, ok := completionRatioVal.(float64); ok { + if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { + confidenceMap[channel.name][modelName] = false + } + } + } + } + } + } + } else { + // 如果不是从pricing接口获取的数据,则全部标记为可信 + for modelName := range allModels { + confidenceMap[channel.name][modelName] = true + } + } + } - for modelName, ratioMap := range differences { - for ratioType, item := range ratioMap { - for chName := range item.Upstreams { - if !channelHasDiff[chName] { - delete(item.Upstreams, chName) - delete(item.Confidence, chName) - } - } + for modelName := range allModels { + for _, ratioType := range ratioTypes { + var localValue interface{} = nil + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + if val, exists := localRatio[modelName]; exists { + localValue = val + } + } + } - allSame := true - for _, v := range item.Upstreams { - if v != "same" { - allSame = false - break - } - } - if len(item.Upstreams) == 0 || allSame { - delete(ratioMap, ratioType) - } else { - differences[modelName][ratioType] = item - } - } + upstreamValues := make(map[string]interface{}) + confidenceValues := make(map[string]bool) + hasUpstreamValue := false + hasDifference := false - if len(ratioMap) == 0 { - delete(differences, modelName) - } - } + for _, channel := range successfulChannels { + var upstreamValue interface{} = nil - return differences + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + if val, exists := upstreamRatio[modelName]; exists { + upstreamValue = val + hasUpstreamValue = true + + if localValue != nil && localValue != val { + hasDifference = true + } else if localValue == val { + upstreamValue = "same" + } + } + } + if upstreamValue == nil && 
localValue == nil { + upstreamValue = "same" + } + + if localValue == nil && upstreamValue != nil && upstreamValue != "same" { + hasDifference = true + } + + upstreamValues[channel.name] = upstreamValue + + confidenceValues[channel.name] = confidenceMap[channel.name][modelName] + } + + shouldInclude := false + + if localValue != nil { + if hasDifference { + shouldInclude = true + } + } else { + if hasUpstreamValue { + shouldInclude = true + } + } + + if shouldInclude { + if differences[modelName] == nil { + differences[modelName] = make(map[string]dto.DifferenceItem) + } + differences[modelName][ratioType] = dto.DifferenceItem{ + Current: localValue, + Upstreams: upstreamValues, + Confidence: confidenceValues, + } + } + } + } + + channelHasDiff := make(map[string]bool) + for _, ratioMap := range differences { + for _, item := range ratioMap { + for chName, val := range item.Upstreams { + if val != nil && val != "same" { + channelHasDiff[chName] = true + } + } + } + } + + for modelName, ratioMap := range differences { + for ratioType, item := range ratioMap { + for chName := range item.Upstreams { + if !channelHasDiff[chName] { + delete(item.Upstreams, chName) + delete(item.Confidence, chName) + } + } + + allSame := true + for _, v := range item.Upstreams { + if v != "same" { + allSame = false + break + } + } + if len(item.Upstreams) == 0 || allSame { + delete(ratioMap, ratioType) + } else { + differences[modelName][ratioType] = item + } + } + + if len(ratioMap) == 0 { + delete(differences, modelName) + } + } + + return differences } func GetSyncableChannels(c *gin.Context) { - channels, err := model.GetAllChannels(0, 0, true, false) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } + channels, err := model.GetAllChannels(0, 0, true, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } - var syncableChannels []dto.SyncableChannel - for _, 
channel := range channels { - if channel.GetBaseURL() != "" { - syncableChannels = append(syncableChannels, dto.SyncableChannel{ - ID: channel.Id, - Name: channel.Name, - BaseURL: channel.GetBaseURL(), - Status: channel.Status, - }) - } - } + var syncableChannels []dto.SyncableChannel + for _, channel := range channels { + if channel.GetBaseURL() != "" { + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: channel.Id, + Name: channel.Name, + BaseURL: channel.GetBaseURL(), + Status: channel.Status, + }) + } + } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": syncableChannels, - }) -} \ No newline at end of file + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": syncableChannels, + }) +} diff --git a/controller/relay.go b/controller/relay.go index d235f550..583ac036 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,21 +2,22 @@ package controller import ( "bytes" - "errors" "fmt" "io" "log" "net/http" "one-api/common" "one-api/constant" - constant2 "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/relay" + relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/setting" "one-api/types" "strings" @@ -24,81 +25,196 @@ import ( "github.com/gorilla/websocket" ) -func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { +func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError { var err *types.NewAPIError - switch relayMode { + switch info.RelayMode { case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: - err = relay.ImageHelper(c) + err = relay.ImageHelper(c, info) case relayconstant.RelayModeAudioSpeech: fallthrough case relayconstant.RelayModeAudioTranslation: fallthrough case relayconstant.RelayModeAudioTranscription: - err = relay.AudioHelper(c) + err = relay.AudioHelper(c, info) case 
relayconstant.RelayModeRerank: - err = relay.RerankHelper(c, relayMode) + err = relay.RerankHelper(c, info) case relayconstant.RelayModeEmbeddings: - err = relay.EmbeddingHelper(c) + err = relay.EmbeddingHelper(c, info) case relayconstant.RelayModeResponses: - err = relay.ResponsesHelper(c) - case relayconstant.RelayModeGemini: - if strings.Contains(c.Request.URL.Path, "embed") { - err = relay.GeminiEmbeddingHandler(c) - } else { - err = relay.GeminiHelper(c) - } + err = relay.ResponsesHelper(c, info) default: - err = relay.TextHelper(c) + err = relay.TextHelper(c, info) } - - if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) { - // 保存错误日志到mysql中 - userId := c.GetInt("id") - tokenName := c.GetString("token_name") - modelName := c.GetString("original_model") - tokenId := c.GetInt("token_id") - userGroup := c.GetString("group") - channelId := c.GetInt("channel_id") - other := make(map[string]interface{}) - other["error_type"] = err.GetErrorType() - other["error_code"] = err.GetErrorCode() - other["status_code"] = err.StatusCode - other["channel_id"] = channelId - other["channel_name"] = c.GetString("channel_name") - other["channel_type"] = c.GetInt("channel_type") - adminInfo := make(map[string]interface{}) - adminInfo["use_channel"] = c.GetStringSlice("use_channel") - isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey) - if isMultiKey { - adminInfo["is_multi_key"] = true - adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex) - } - other["admin_info"] = adminInfo - model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other) - } - return err } -func Relay(c *gin.Context) { - relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) +func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError { + var err *types.NewAPIError + if strings.Contains(c.Request.URL.Path, "embed") { 
+ err = relay.GeminiEmbeddingHandler(c, info) + } else { + err = relay.GeminiHelper(c, info) + } + return err +} + +func Relay(c *gin.Context, relayFormat types.RelayFormat) { + requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError + + var ( + newAPIError *types.NewAPIError + ws *websocket.Conn + ) + + if relayFormat == types.RelayFormatOpenAIRealtime { + var err error + ws, err = upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) + return + } + defer ws.Close() + } + + defer func() { + if newAPIError != nil { + newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) + switch relayFormat { + case types.RelayFormatOpenAIRealtime: + helper.WssError(c, ws, newAPIError.ToOpenAIError()) + case types.RelayFormatClaude: + c.JSON(newAPIError.StatusCode, gin.H{ + "type": "error", + "error": newAPIError.ToClaudeError(), + }) + default: + c.JSON(newAPIError.StatusCode, gin.H{ + "error": newAPIError.ToOpenAIError(), + }) + } + } + }() + + request, err := helper.GetAndValidateRequest(c, relayFormat) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) + return + } + + //includeUsage := true + //// 判断用户是否需要返回使用情况 + //if textRequest.StreamOptions != nil { + // includeUsage = textRequest.StreamOptions.IncludeUsage + //} + // + //// 如果不支持StreamOptions,将StreamOptions设置为nil + //if !relayInfo.SupportStreamOptions || !textRequest.Stream { + // textRequest.StreamOptions = nil + //} else { + // // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions + // if constant.ForceStreamOption { + // textRequest.StreamOptions = &dto.StreamOptions{ + // IncludeUsage: true, + // } + // } + //} + // + //relayInfo.ShouldIncludeUsage = includeUsage + + relayInfo, err := 
relaycommon.GenRelayInfo(c, relayFormat, request, ws) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed) + return + } + + meta := request.GetTokenCountMeta() + + if setting.ShouldCheckPromptSensitive() { + words, err := service.CheckSensitiveText(meta.CombineText) + if err != nil { + logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) + newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return + } + } + + tokens, err := service.CountRequestToken(c, meta, relayInfo) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed) + return + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeModelPriceError) + return + } + + preConsumedQuota, newApiErr := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newApiErr != nil { + return + } + + defer func() { + if newApiErr != nil { + service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) + } + }() for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) newAPIError = err break } - newAPIError = relayRequest(c, relayMode, channel) + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + switch relayFormat { + case types.RelayFormatOpenAIRealtime: + newAPIError = relay.WssHelper(c, ws) + case types.RelayFormatClaude: + newAPIError = relay.ClaudeHelper(c, relayInfo) + case types.RelayFormatGemini: + newAPIError = geminiRelayHandler(c, relayInfo) + default: + newAPIError = relayHandler(c, relayInfo) + } if newAPIError == nil { - return // 成功处理请求,直接返回 + return + } else { + if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) { + // 保存错误日志到mysql中 + userId 
:= c.GetInt("id") + tokenName := c.GetString("token_name") + modelName := c.GetString("original_model") + tokenId := c.GetInt("token_id") + userGroup := c.GetString("group") + channelId := c.GetInt("channel_id") + other := make(map[string]interface{}) + other["error_type"] = newAPIError.GetErrorType() + other["error_code"] = newAPIError.GetErrorCode() + other["status_code"] = newAPIError.StatusCode + other["channel_id"] = channelId + other["channel_name"] = c.GetString("channel_name") + other["channel_type"] = c.GetInt("channel_type") + adminInfo := make(map[string]interface{}) + adminInfo["use_channel"] = c.GetStringSlice("use_channel") + isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey) + if isMultiKey { + adminInfo["is_multi_key"] = true + adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex) + } + other["admin_info"] = adminInfo + model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other) + } } go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) @@ -107,21 +223,11 @@ func Relay(c *gin.Context) { break } } + useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - //if newAPIError.StatusCode == http.StatusTooManyRequests { - // common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error())) - // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - //} - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - c.JSON(newAPIError.StatusCode, gin.H{ - "error": newAPIError.ToOpenAIError(), - }) + logger.LogInfo(c, 
retryLogStr) } } @@ -132,122 +238,6 @@ var upgrader = websocket.Upgrader{ }, } -func WssRelay(c *gin.Context) { - // 将 HTTP 连接升级为 WebSocket 连接 - - ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) - defer ws.Close() - - if err != nil { - helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) - return - } - - relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) - requestId := c.GetString(common.RequestIdKey) - group := c.GetString("group") - //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 - originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError - - for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, err.Error()) - newAPIError = err - break - } - - newAPIError = wssRequest(c, ws, relayMode, channel) - - if newAPIError == nil { - return // 成功处理请求,直接返回 - } - - go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - - if !shouldRetry(c, newAPIError, common.RetryTimes-i) { - break - } - } - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - //if newAPIError.StatusCode == http.StatusTooManyRequests { - // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - //} - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - helper.WssError(c, ws, newAPIError.ToOpenAIError()) - } -} - -func RelayClaude(c *gin.Context) { - //relayMode := constant.Path2RelayMode(c.Request.URL.Path) - requestId := c.GetString(common.RequestIdKey) - group := c.GetString("group") - 
originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError - - for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, err.Error()) - newAPIError = err - break - } - - newAPIError = claudeRequest(c, channel) - - if newAPIError == nil { - return // 成功处理请求,直接返回 - } - - go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - - if !shouldRetry(c, newAPIError, common.RetryTimes-i) { - break - } - } - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - c.JSON(newAPIError.StatusCode, gin.H{ - "type": "error", - "error": newAPIError.ToClaudeError(), - }) - } -} - -func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relayHandler(c, relayMode) -} - -func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relay.WssHelper(c, ws) -} - -func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relay.ClaudeHelper(c) -} - func addUsedChannel(c *gin.Context, channelId int) { 
useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) @@ -270,10 +260,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m } channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { - return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) + return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } if channel == nil { - return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) + return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) if newAPIError != nil { @@ -327,7 +317,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { service.DisableChannel(channelError, err.Error()) } @@ -362,7 +352,7 
@@ func RelayMidjourney(c *gin.Context) { "code": err.Code, }) channelId := c.GetInt("channel_id") - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) } } @@ -404,7 +394,7 @@ func RelayTask(c *gin.Context) { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { channel, newAPIError := getChannel(c, group, originalModel, i) if newAPIError != nil { - common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) + logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) break } @@ -412,7 +402,7 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) - common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) //middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, _ := common.GetRequestBody(c) @@ -422,7 +412,7 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) + logger.LogInfo(c, retryLogStr) } if taskErr != nil { if taskErr.StatusCode == http.StatusTooManyRequests { diff --git a/controller/task.go b/controller/task.go index 5fbdb424..a5b28ae2 100644 --- a/controller/task.go +++ b/controller/task.go @@ -10,6 
+10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/relay" "sort" @@ -25,7 +26,7 @@ func UpdateTaskBulk() { //imageModel := "midjourney" for { time.Sleep(time.Duration(15) * time.Second) - common.SysLog("任务进度轮询开始") + logger.SysLog("任务进度轮询开始") ctx := context.TODO() allTasks := model.GetAllUnFinishSyncTasks(500) platformTask := make(map[constant.TaskPlatform][]*model.Task) @@ -54,9 +55,9 @@ func UpdateTaskBulk() { "progress": "100%", }) if err != nil { - common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) } else { - common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { @@ -65,7 +66,7 @@ func UpdateTaskBulk() { UpdateTaskByPlatform(platform, taskChannelM, taskM) } - common.SysLog("任务进度轮询完成") + logger.SysLog("任务进度轮询完成") } } @@ -77,7 +78,7 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) default: if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) + logger.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) } } } @@ -86,27 +87,27 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM for channelId, taskIds := range taskChannelM { err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) if err != nil { - common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) } } return nil } func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { - common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 
未完成的任务有: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { return nil } channel, err := model.CacheGetChannel(channelId) if err != nil { - common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + logger.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) err = model.TaskBulkUpdate(taskIds, map[string]any{ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), "status": "FAILURE", "progress": "100%", }) if err != nil { - common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) + logger.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) } return err } @@ -118,27 +119,27 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas "ids": taskIds, }) if err != nil { - common.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) + logger.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) return err } if resp.StatusCode != http.StatusOK { - common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) } defer resp.Body.Close() responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) + logger.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) return err } var responseItems dto.TaskResponse[[]dto.SunoDataResponse] err = json.Unmarshal(responseBody, &responseItems) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) return err } if !responseItems.IsSuccess() { - common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) + logger.SysLog(fmt.Sprintf("渠道 #%d 
未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) return err } @@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { - common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) task.Progress = "100%" //err = model.CacheUpdateUserQuota(task.UserId) ? if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) + logger.LogError(ctx, "error update user quota cache: "+err.Error()) } else { quota := task.Quota if quota != 0 { err = model.IncreaseUserQuota(task.UserId, quota, false) if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota)) + logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } @@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas err = task.Update() if err != nil { - common.SysError("UpdateMidjourneyTask task error: " + err.Error()) + logger.SysError("UpdateMidjourneyTask task error: " + err.Error()) } } return nil diff --git a/controller/task_video.go b/controller/task_video.go index 914bf6e6..dca42955 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -5,9 +5,9 @@ import ( "encoding/json" "fmt" "io" - "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/relay" "one-api/relay/channel" @@ -18,14 +18,14 @@ import ( func UpdateVideoTaskAll(ctx 
context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { for channelId, taskIds := range taskChannelM { if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { - common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) } } return nil } func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { - common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) if len(taskIds) == 0 { return nil } @@ -37,7 +37,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha "progress": "100%", }) if errUpdate != nil { - common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + logger.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) } return fmt.Errorf("CacheGetChannel failed: %w", err) } @@ -47,7 +47,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha } for _, taskId := range taskIds { if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { - common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) } } return nil @@ -61,7 +61,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task := taskM[taskId] if task == nil { - common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) return fmt.Errorf("task %s not found", taskId) } resp, err := 
adaptor.FetchTask(baseURL, channel.Key, map[string]any{ @@ -124,13 +124,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.FinishTime = now } task.FailReason = taskResult.Reason - common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) quota := task.Quota if quota != 0 { if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { - common.LogError(ctx, "Failed to increase user quota: "+err.Error()) + logger.LogError(ctx, "Failed to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) + logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } default: @@ -140,7 +140,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.Progress = taskResult.Progress } if err := task.Update(); err != nil { - common.SysError("UpdateVideoTask task error: " + err.Error()) + logger.SysError("UpdateVideoTask task error: " + err.Error()) } return nil diff --git a/controller/token.go b/controller/token.go index 62eb5474..db575fec 100644 --- a/controller/token.go +++ b/controller/token.go @@ -3,6 +3,7 @@ package controller import ( "net/http" "one-api/common" + "one-api/logger" "one-api/model" "strconv" @@ -102,7 +103,7 @@ func AddToken(c *gin.Context) { "success": false, "message": "生成令牌失败", }) - common.SysError("failed to generate token key: " + err.Error()) + logger.SysError("failed to generate token key: " + err.Error()) return } cleanToken := model.Token{ diff --git a/controller/topup.go b/controller/topup.go index 827dda39..3f3c8623 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -5,6 +5,7 @@ import ( "log" "net/url" "one-api/common" + "one-api/logger" "one-api/model" 
"one-api/service" "one-api/setting" @@ -231,7 +232,7 @@ func EpayNotify(c *gin.Context) { return } log.Printf("易支付回调更新用户成功 %v", topUp) - model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money)) + model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money)) } } else { log.Printf("易支付异常回调: %v", verifyInfo) diff --git a/controller/twofa.go b/controller/twofa.go index 9f48eed8..0ab66029 100644 --- a/controller/twofa.go +++ b/controller/twofa.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/logger" "one-api/model" "strconv" @@ -70,7 +71,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成2FA密钥失败", }) - common.SysError("生成TOTP密钥失败: " + err.Error()) + logger.SysError("生成TOTP密钥失败: " + err.Error()) return } @@ -81,7 +82,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - common.SysError("生成备用码失败: " + err.Error()) + logger.SysError("生成备用码失败: " + err.Error()) return } @@ -115,7 +116,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - common.SysError("保存备用码失败: " + err.Error()) + logger.SysError("保存备用码失败: " + err.Error()) return } @@ -294,7 +295,7 @@ func Get2FAStatus(c *gin.Context) { // 获取剩余备用码数量 backupCount, err := model.GetUnusedBackupCodeCount(userId) if err != nil { - common.SysError("获取备用码数量失败: " + err.Error()) + logger.SysError("获取备用码数量失败: " + err.Error()) } else { status["backup_codes_remaining"] = backupCount } @@ -368,7 +369,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - common.SysError("生成备用码失败: " + err.Error()) + logger.SysError("生成备用码失败: " + err.Error()) return } @@ -378,7 +379,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - common.SysError("保存备用码失败: " + err.Error()) + logger.SysError("保存备用码失败: " + err.Error()) return } diff 
--git a/controller/user.go b/controller/user.go index 29cf83e1..8ce44fa6 100644 --- a/controller/user.go +++ b/controller/user.go @@ -7,6 +7,7 @@ import ( "net/url" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/setting" "strconv" @@ -192,7 +193,7 @@ func Register(c *gin.Context) { "success": false, "message": "数据库错误,请稍后重试", }) - common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) + logger.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) return } if exist { @@ -235,7 +236,7 @@ func Register(c *gin.Context) { "success": false, "message": "生成默认令牌失败", }) - common.SysError("failed to generate token key: " + err.Error()) + logger.SysError("failed to generate token key: " + err.Error()) return } // 生成默认令牌 @@ -342,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) { "success": false, "message": "生成失败", }) - common.SysError("failed to generate key: " + err.Error()) + logger.SysError("failed to generate key: " + err.Error()) return } user.SetAccessToken(key) @@ -517,7 +518,7 @@ func UpdateUser(c *gin.Context) { return } if originUser.Quota != updatedUser.Quota { - model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) + model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota))) } c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/dto/audio.go b/dto/audio.go index c36b3da5..81872c69 100644 --- a/dto/audio.go +++ b/dto/audio.go @@ -1,5 +1,11 @@ package dto +import ( + "one-api/types" + + "github.com/gin-gonic/gin" +) + type AudioRequest struct { Model string `json:"model"` Input string `json:"input"` @@ -8,6 +14,18 @@ type AudioRequest struct { ResponseFormat string `json:"response_format,omitempty"` } +func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta { + meta := &types.TokenCountMeta{ + 
CombineText: r.Input, + TokenType: types.TokenTypeTextNumber, + } + return meta +} + +func (r *AudioRequest) IsStream(c *gin.Context) bool { + return false +} + type AudioResponse struct { Text string `json:"text"` } diff --git a/dto/claude.go b/dto/claude.go index 58a09217..2b3adf19 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -5,6 +5,9 @@ import ( "fmt" "one-api/common" "one-api/types" + "strings" + + "github.com/gin-gonic/gin" ) type ClaudeMetadata struct { @@ -81,7 +84,7 @@ func (c *ClaudeMediaMessage) GetStringContent() string { } func (c *ClaudeMediaMessage) GetJsonRowString() string { - jsonContent, _ := json.Marshal(c) + jsonContent, _ := common.Marshal(c) return string(jsonContent) } @@ -199,6 +202,129 @@ type ClaudeRequest struct { Thinking *Thinking `json:"thinking,omitempty"` } +func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { + var tokenCountMeta = types.TokenCountMeta{ + TokenType: types.TokenTypeTextNumber, + MaxTokens: int(c.MaxTokens), + } + + var texts = make([]string, 0) + var fileMeta = make([]*types.FileMeta, 0) + + // system + if c.System != nil { + if c.IsStringSystem() { + sys := c.GetStringSystem() + if sys != "" { + texts = append(texts, sys) + } + } else { + systemMedia := c.ParseSystem() + for _, media := range systemMedia { + switch media.Type { + case "text": + texts = append(texts, media.GetText()) + case "image": + if media.Source != nil { + data := media.Source.Url + if data == "" { + data = common.Interface2String(media.Source.Data) + } + if data != "" { + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) + } + } + } + } + } + } + + // messages + for _, message := range c.Messages { + tokenCountMeta.MessagesCount++ + texts = append(texts, message.Role) + if message.IsStringContent() { + content := message.GetStringContent() + if content != "" { + texts = append(texts, content) + } + continue + } + + content, _ := message.ParseContent() + for _, media := range content { + 
switch media.Type { + case "text": + texts = append(texts, media.GetText()) + case "image": + if media.Source != nil { + data := media.Source.Url + if data == "" { + data = common.Interface2String(media.Source.Data) + } + if data != "" { + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) + } + } + case "tool_use": + if media.Name != "" { + texts = append(texts, media.Name) + } + if media.Input != nil { + b, _ := common.Marshal(media.Input) + texts = append(texts, string(b)) + } + case "tool_result": + if media.Content != nil { + b, _ := common.Marshal(media.Content) + texts = append(texts, string(b)) + } + } + } + } + + // tools + if c.Tools != nil { + tools := c.GetTools() + normalTools, webSearchTools := ProcessTools(tools) + if normalTools != nil { + for _, t := range normalTools { + tokenCountMeta.ToolsCount++ + if t.Name != "" { + texts = append(texts, t.Name) + } + if t.Description != "" { + texts = append(texts, t.Description) + } + if t.InputSchema != nil { + b, _ := common.Marshal(t.InputSchema) + texts = append(texts, string(b)) + } + } + } + if webSearchTools != nil { + for _, t := range webSearchTools { + tokenCountMeta.ToolsCount++ + if t.Name != "" { + texts = append(texts, t.Name) + } + if t.UserLocation != nil { + b, _ := common.Marshal(t.UserLocation) + texts = append(texts, string(b)) + } + } + } + } + + tokenCountMeta.CombineText = strings.Join(texts, "\n") + tokenCountMeta.Files = fileMeta + return &tokenCountMeta +} + +func (claudeRequest *ClaudeRequest) IsStream(c *gin.Context) bool { + return claudeRequest.Stream +} + func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string { for _, message := range c.Messages { content, _ := message.ParseContent() diff --git a/dto/embedding.go b/dto/embedding.go index 9d722292..fff37776 100644 --- a/dto/embedding.go +++ b/dto/embedding.go @@ -1,5 +1,12 @@ package dto +import ( + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + type 
EmbeddingOptions struct { Seed int `json:"seed,omitempty"` Temperature *float64 `json:"temperature,omitempty"` @@ -24,9 +31,26 @@ type EmbeddingRequest struct { PresencePenalty float64 `json:"presence_penalty,omitempty"` } -func (r EmbeddingRequest) ParseInput() []string { +func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta { + var texts = make([]string, 0) + + inputs := r.ParseInput() + for _, input := range inputs { + texts = append(texts, input) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + } +} + +func (r *EmbeddingRequest) IsStream(c *gin.Context) bool { + return false +} + +func (r *EmbeddingRequest) ParseInput() []string { if r.Input == nil { - return nil + return make([]string, 0) } var input []string switch r.Input.(type) { diff --git a/dto/gemini.go b/dto/gemini.go index 6cb3e17a..b327de62 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -2,7 +2,10 @@ package dto import ( "encoding/json" + "github.com/gin-gonic/gin" "one-api/common" + "one-api/logger" + "one-api/types" "strings" ) @@ -14,19 +17,75 @@ type GeminiChatRequest struct { SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` } +func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { + var files []*types.FileMeta = make([]*types.FileMeta, 0) + + var maxTokens int + + if r.GenerationConfig.MaxOutputTokens > 0 { + maxTokens = int(r.GenerationConfig.MaxOutputTokens) + } + + var inputTexts []string + for _, content := range r.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + if part.InlineData != nil && part.InlineData.Data != "" { + if strings.HasPrefix(part.InlineData.MimeType, "image/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeImage, + Data: part.InlineData.Data, + }) + } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeAudio, 
+ Data: part.InlineData.Data, + }) + } else if strings.HasPrefix(part.InlineData.MimeType, "video/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeVideo, + Data: part.InlineData.Data, + }) + } else { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeFile, + Data: part.InlineData.Data, + }) + } + } + } + } + + inputText := strings.Join(inputTexts, "\n") + return &types.TokenCountMeta{ + CombineText: inputText, + Files: files, + MaxTokens: maxTokens, + } +} + +func (r *GeminiChatRequest) IsStream(c *gin.Context) bool { + if c.Query("alt") == "sse" { + return true + } + return false +} + func (r *GeminiChatRequest) GetTools() []GeminiChatTool { var tools []GeminiChatTool if strings.HasSuffix(string(r.Tools), "[") { // is array if err := common.Unmarshal(r.Tools, &tools); err != nil { - common.LogError(nil, "error_unmarshalling_tools: "+err.Error()) + logger.LogError(nil, "error_unmarshalling_tools: "+err.Error()) return nil } } else if strings.HasPrefix(string(r.Tools), "{") { // is object singleTool := GeminiChatTool{} if err := common.Unmarshal(r.Tools, &singleTool); err != nil { - common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error()) + logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error()) return nil } tools = []GeminiChatTool{singleTool} @@ -43,7 +102,7 @@ func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) { // Marshal the tools to JSON data, err := common.Marshal(tools) if err != nil { - common.LogError(nil, "error_marshalling_tools: "+err.Error()) + logger.LogError(nil, "error_marshalling_tools: "+err.Error()) return } r.Tools = data diff --git a/dto/dalle.go b/dto/openai_image.go similarity index 51% rename from dto/dalle.go rename to dto/openai_image.go index ce2f6361..7431935b 100644 --- a/dto/dalle.go +++ b/dto/openai_image.go @@ -1,11 +1,17 @@ package dto -import "encoding/json" +import ( + "encoding/json" + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) type 
ImageRequest struct { Model string `json:"model"` Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` + N uint `json:"n,omitempty"` Size string `json:"size,omitempty"` Quality string `json:"quality,omitempty"` ResponseFormat string `json:"response_format,omitempty"` @@ -18,6 +24,42 @@ type ImageRequest struct { Watermark *bool `json:"watermark,omitempty"` } +func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { + var sizeRatio = 1.0 + var qualityRatio = 1.0 + + if strings.HasPrefix(i.Model, "dall-e") { + // Size + if i.Size == "256x256" { + sizeRatio = 0.4 + } else if i.Size == "512x512" { + sizeRatio = 0.45 + } else if i.Size == "1024x1024" { + sizeRatio = 1 + } else if i.Size == "1024x1792" || i.Size == "1792x1024" { + sizeRatio = 2 + } + + if i.Model == "dall-e-3" && i.Quality == "hd" { + qualityRatio = 2.0 + if i.Size == "1024x1792" || i.Size == "1792x1024" { + qualityRatio = 1.5 + } + } + } + + // not support token count for dalle + return &types.TokenCountMeta{ + CombineText: i.Prompt, + MaxTokens: 1584, + ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N), + } +} + +func (i *ImageRequest) IsStream(c *gin.Context) bool { + return false +} + type ImageResponse struct { Data []ImageData `json:"data"` Created int64 `json:"created"` diff --git a/dto/openai_request.go b/dto/openai_request.go index 7a23ca5c..0c01c503 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -2,8 +2,12 @@ package dto import ( "encoding/json" + "fmt" "one-api/common" + "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) type ResponseFormat struct { @@ -67,6 +71,116 @@ type GeneralOpenAIRequest struct { Extra map[string]json.RawMessage `json:"-"` } +func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { + var tokenCountMeta types.TokenCountMeta + var texts = make([]string, 0) + var fileMeta = make([]*types.FileMeta, 0) + + if r.Prompt != nil { + switch v := r.Prompt.(type) { + case string: + texts = 
append(texts, v) + case []any: + for _, item := range v { + if str, ok := item.(string); ok { + texts = append(texts, str) + } + } + default: + texts = append(texts, fmt.Sprintf("%v", r.Prompt)) + } + } + + if r.Input != nil { + inputs := r.ParseInput() + texts = append(texts, inputs...) + } + + if r.MaxCompletionTokens > r.MaxTokens { + tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens) + } else { + tokenCountMeta.MaxTokens = int(r.MaxTokens) + } + + for _, message := range r.Messages { + tokenCountMeta.MessagesCount++ + texts = append(texts, message.Role) + if message.Content != nil { + if message.Name != nil { + tokenCountMeta.NameCount++ + texts = append(texts, *message.Name) + } + arrayContent := message.ParseContent() + for _, m := range arrayContent { + if m.Type == ContentTypeImageURL { + imageUrl := m.GetImageMedia() + if imageUrl != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeImage, + } + meta.Data = imageUrl.Url + meta.Detail = imageUrl.Detail + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeInputAudio { + inputAudio := m.GetInputAudio() + if inputAudio != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeAudio, + } + meta.Data = inputAudio.Data + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeFile { + file := m.GetFile() + if file != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeFile, + } + meta.Data = file.FileData + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeVideoUrl { + videoUrl := m.GetVideoUrl() + if videoUrl != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeVideo, + } + meta.Data = videoUrl.Url + fileMeta = append(fileMeta, meta) + } + } else { + texts = append(texts, m.Text) + } + } + } + } + + if r.Tools != nil { + openaiTools := r.Tools + for _, tool := range openaiTools { + tokenCountMeta.ToolsCount++ + texts = append(texts, tool.Function.Name) + if tool.Function.Description != "" { + texts = append(texts, 
tool.Function.Description) + } + if tool.Function.Parameters != nil { + texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters)) + } + } + //toolTokens := CountTokenInput(countStr, request.Model) + //tkm += 8 + //tkm += toolTokens + } + tokenCountMeta.CombineText = strings.Join(texts, "\n") + tokenCountMeta.Files = fileMeta + return &tokenCountMeta +} + +func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool { + return r.Stream +} + func (r *GeneralOpenAIRequest) ToMap() map[string]any { result := make(map[string]any) data, _ := common.Marshal(r) @@ -202,10 +316,25 @@ func (m *MediaContent) GetFile() *MessageFile { return nil } +func (m *MediaContent) GetVideoUrl() *MessageVideoUrl { + if m.VideoUrl != nil { + if _, ok := m.VideoUrl.(*MessageVideoUrl); ok { + return m.VideoUrl.(*MessageVideoUrl) + } + if itemMap, ok := m.VideoUrl.(map[string]any); ok { + out := &MessageVideoUrl{ + Url: common.Interface2String(itemMap["url"]), + } + return out + } + } + return nil +} + type MessageImageUrl struct { - Url string `json:"url"` - Detail string `json:"detail"` - MimeType string + Url string `json:"url"` + Detail string `json:"detail"` + //MimeType string } func (m *MessageImageUrl) IsRemoteImage() bool { @@ -233,6 +362,7 @@ const ( ContentTypeInputAudio = "input_audio" ContentTypeFile = "file" ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别 + //ContentTypeAudioUrl = "audio_url" ) func (m *Message) GetPrefix() bool { @@ -623,7 +753,7 @@ type WebSearchOptions struct { // https://platform.openai.com/docs/api-reference/responses/create type OpenAIResponsesRequest struct { Model string `json:"model"` - Input json.RawMessage `json:"input,omitempty"` + Input any `json:"input,omitempty"` Include json.RawMessage `json:"include,omitempty"` Instructions json.RawMessage `json:"instructions,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"` @@ -645,28 +775,145 @@ type OpenAIResponsesRequest struct { Prompt json.RawMessage `json:"prompt,omitempty"` } 
+func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { + var fileMeta = make([]*types.FileMeta, 0) + var texts = make([]string, 0) + + if r.Input != nil { + inputs := r.ParseInput() + for _, input := range inputs { + if input.Type == "input_image" { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeImage, + Data: input.ImageUrl, + Detail: input.Detail, + }) + } else if input.Type == "input_file" { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeFile, + Data: input.FileUrl, + }) + } else { + texts = append(texts, input.Text) + } + } + } + + if len(r.Instructions) > 0 { + texts = append(texts, string(r.Instructions)) + } + + if len(r.Metadata) > 0 { + texts = append(texts, string(r.Metadata)) + } + + if len(r.Text) > 0 { + texts = append(texts, string(r.Text)) + } + + if len(r.ToolChoice) > 0 { + texts = append(texts, string(r.ToolChoice)) + } + + if len(r.Prompt) > 0 { + texts = append(texts, string(r.Prompt)) + } + + if len(r.Tools) > 0 { + toolStr, _ := common.Marshal(r.Tools) + texts = append(texts, string(toolStr)) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + Files: fileMeta, + MaxTokens: int(r.MaxOutputTokens), + } +} + +func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool { + return r.Stream +} + type Reasoning struct { Effort string `json:"effort,omitempty"` Summary string `json:"summary,omitempty"` } -//type ResponsesToolsCall struct { -// Type string `json:"type"` -// // Web Search -// UserLocation json.RawMessage `json:"user_location,omitempty"` -// SearchContextSize string `json:"search_context_size,omitempty"` -// // File Search -// VectorStoreIds []string `json:"vector_store_ids,omitempty"` -// MaxNumResults uint `json:"max_num_results,omitempty"` -// Filters json.RawMessage `json:"filters,omitempty"` -// // Computer Use -// DisplayWidth uint `json:"display_width,omitempty"` -// DisplayHeight uint `json:"display_height,omitempty"` -// 
Environment string `json:"environment,omitempty"` -// // Function -// Name string `json:"name,omitempty"` -// Description string `json:"description,omitempty"` -// Parameters json.RawMessage `json:"parameters,omitempty"` -// Function json.RawMessage `json:"function,omitempty"` -// Container json.RawMessage `json:"container,omitempty"` -//} +type MediaInput struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + FileUrl string `json:"file_url,omitempty"` + ImageUrl string `json:"image_url,omitempty"` + Detail string `json:"detail,omitempty"` // 仅 input_image 有效 +} + +// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput. +// Reference implementation mirrors Message.ParseContent: +// - input can be a string, treated as an input_text item +// - input can be an array of objects with a `type` field +// supported types: input_text, input_image, input_file +func (r *OpenAIResponsesRequest) ParseInput() []MediaInput { + if r.Input == nil { + return nil + } + + var inputs []MediaInput + + // Try string first + if str, ok := r.Input.(string); ok { + inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) + return inputs + } + + // Try array of parts + if array, ok := r.Input.([]any); ok { + for _, itemAny := range array { + // Already parsed MediaInput + if media, ok := itemAny.(MediaInput); ok { + inputs = append(inputs, media) + continue + } + // Generic map + item, ok := itemAny.(map[string]any) + if !ok { + continue + } + typeVal, ok := item["type"].(string) + if !ok { + continue + } + switch typeVal { + case "input_text": + text, _ := item["text"].(string) + inputs = append(inputs, MediaInput{Type: "input_text", Text: text}) + case "input_image": + // image_url may be string or object with url field + var imageUrl string + switch v := item["image_url"].(type) { + case string: + imageUrl = v + case map[string]any: + if url, ok := v["url"].(string); ok { + imageUrl = url + } + } + inputs = append(inputs, 
MediaInput{Type: "input_image", ImageUrl: imageUrl}) + case "input_file": + // file_url may be string or object with url field + var fileUrl string + switch v := item["file_url"].(type) { + case string: + fileUrl = v + case map[string]any: + if url, ok := v["url"].(string); ok { + fileUrl = url + } + } + inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl}) + } + } + } + + return inputs +} diff --git a/dto/request_common.go b/dto/request_common.go new file mode 100644 index 00000000..e5dde8b5 --- /dev/null +++ b/dto/request_common.go @@ -0,0 +1,11 @@ +package dto + +import ( + "github.com/gin-gonic/gin" + "one-api/types" +) + +type Request interface { + GetTokenCountMeta() *types.TokenCountMeta + IsStream(c *gin.Context) bool +} diff --git a/dto/rerank.go b/dto/rerank.go index 5ea68cba..ca4da9e1 100644 --- a/dto/rerank.go +++ b/dto/rerank.go @@ -1,5 +1,12 @@ package dto +import ( + "fmt" + "github.com/gin-gonic/gin" + "one-api/types" + "strings" +) + type RerankRequest struct { Documents []any `json:"documents"` Query string `json:"query"` @@ -10,6 +17,26 @@ type RerankRequest struct { OverLapTokens int `json:"overlap_tokens,omitempty"` } +func (r *RerankRequest) IsStream(c *gin.Context) bool { + return false +} + +func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta { + var texts = make([]string, 0) + + for _, document := range r.Documents { + texts = append(texts, fmt.Sprintf("%v", document)) + } + + if r.Query != "" { + texts = append(texts, r.Query) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + } +} + func (r *RerankRequest) GetReturnDocuments() bool { if r.ReturnDocuments == nil { return false diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 00000000..ca81d624 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,115 @@ +package logger + +import ( + "context" + "encoding/json" + "fmt" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "io" + "log" + 
"one-api/common" + "os" + "path/filepath" + "sync" + "time" +) + +const ( + loggerINFO = "INFO" + loggerWarn = "WARN" + loggerError = "ERR" + loggerDebug = "DEBUG" +) + +const maxLogCount = 1000000 + +var logCount int +var setupLogLock sync.Mutex +var setupLogWorking bool + +func SetupLogger() { + if *common.LogDir != "" { + ok := setupLogLock.TryLock() + if !ok { + log.Println("setup log is already working") + return + } + defer func() { + setupLogLock.Unlock() + setupLogWorking = false + }() + logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) + fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatal("failed to open log file") + } + gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) + gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) + } +} + +func LogInfo(ctx context.Context, msg string) { + logHelper(ctx, loggerINFO, msg) +} + +func LogWarn(ctx context.Context, msg string) { + logHelper(ctx, loggerWarn, msg) +} + +func LogError(ctx context.Context, msg string) { + logHelper(ctx, loggerError, msg) +} + +func LogDebug(ctx context.Context, msg string) { + if common.DebugEnabled { + logHelper(ctx, loggerDebug, msg) + } +} + +func logHelper(ctx context.Context, level string, msg string) { + writer := gin.DefaultErrorWriter + if level == loggerINFO { + writer = gin.DefaultWriter + } + id := ctx.Value(common.RequestIdKey) + if id == nil { + id = "SYSTEM" + } + now := time.Now() + _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) + logCount++ // we don't need accurate count, so no lock here + if logCount > maxLogCount && !setupLogWorking { + logCount = 0 + setupLogWorking = true + gopool.Go(func() { + SetupLogger() + }) + } +} + +func LogQuota(quota int) string { + if common.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit) + } else { + return fmt.Sprintf("%d 
点额度", quota) + } +} + +func FormatQuota(quota int) string { + if common.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit) + } else { + return fmt.Sprintf("%d", quota) + } +} + +// LogJson 仅供测试使用 only for test +func LogJson(ctx context.Context, msg string, obj any) { + jsonStr, err := json.Marshal(obj) + if err != nil { + LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) + return + } + LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr))) +} diff --git a/main.go b/main.go index ca3da601..9a5bd652 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "one-api/common" "one-api/constant" "one-api/controller" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/router" @@ -35,22 +36,22 @@ func main() { err := InitResources() if err != nil { - common.FatalLog("failed to initialize resources: " + err.Error()) + logger.FatalLog("failed to initialize resources: " + err.Error()) return } - common.SysLog("New API " + common.Version + " started") + logger.SysLog("New API " + common.Version + " started") if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) } if common.DebugEnabled { - common.SysLog("running in debug mode") + logger.SysLog("running in debug mode") } defer func() { err := model.CloseDB() if err != nil { - common.FatalLog("failed to close database: " + err.Error()) + logger.FatalLog("failed to close database: " + err.Error()) } }() @@ -59,18 +60,18 @@ func main() { common.MemoryCacheEnabled = true } if common.MemoryCacheEnabled { - common.SysLog("memory cache enabled") - common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + logger.SysLog("memory cache enabled") + logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) // Add panic recovery and retry for InitChannelCache func() { defer func() { if r := recover(); r != nil { - common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) + 
logger.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) // Retry once _, _, fixErr := model.FixAbility() if fixErr != nil { - common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) + logger.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) } } }() @@ -89,14 +90,14 @@ func main() { if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) } go controller.AutomaticallyUpdateChannels(frequency) } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) } go controller.AutomaticallyTestChannels(frequency) } @@ -110,7 +111,7 @@ func main() { } if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { common.BatchUpdateEnabled = true - common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") model.InitBatchUpdater() } @@ -119,13 +120,13 @@ func main() { log.Println(http.ListenAndServe("0.0.0.0:8005", nil)) }) go common.Monitor() - common.SysLog("pprof enabled") + logger.SysLog("pprof enabled") } // Initialize HTTP server server := gin.New() server.Use(gin.CustomRecovery(func(c *gin.Context, err any) { - common.SysError(fmt.Sprintf("panic detected: %v", err)) + logger.SysError(fmt.Sprintf("panic detected: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. 
Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), @@ -155,7 +156,7 @@ func main() { } err = server.Run(":" + port) if err != nil { - common.FatalLog("failed to start HTTP server: " + err.Error()) + logger.FatalLog("failed to start HTTP server: " + err.Error()) } } @@ -164,14 +165,14 @@ func InitResources() error { // This is a placeholder function for future resource initialization err := godotenv.Load(".env") if err != nil { - common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") - common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") + logger.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") + logger.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") } // 加载环境变量 common.InitEnv() - common.SetupLogger() + logger.SetupLogger() // Initialize model settings ratio_setting.InitRatioSettings() @@ -183,7 +184,7 @@ func InitResources() error { // Initialize SQL Database err = model.InitDB() if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) + logger.FatalLog("failed to initialize database: " + err.Error()) return err } diff --git a/middleware/recover.go b/middleware/recover.go index 51fc7190..6c9c7ef6 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/logger" "runtime/debug" ) @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - common.SysError(fmt.Sprintf("panic detected: %v", err)) - common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.SysError(fmt.Sprintf("panic detected: %v", err)) + logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) 
c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 26688810..a136a900 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/logger" ) type turnstileCheckResponse struct { @@ -37,7 +38,7 @@ func TurnstileCheck() gin.HandlerFunc { "remoteip": {c.ClientIP()}, }) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/middleware/utils.go b/middleware/utils.go index 082f5657..e23bbff7 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "one-api/common" + "one-api/logger" ) func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { @@ -15,7 +16,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { }, }) c.Abort() - common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) + logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) } func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { @@ -25,5 +26,5 @@ func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, descri "code": code, }) c.Abort() - common.LogError(c.Request.Context(), description) + logger.LogError(c.Request.Context(), description) } diff --git a/model/ability.go b/model/ability.go index 
ce2f299c..ac5530d8 100644 --- a/model/ability.go +++ b/model/ability.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "strings" "sync" @@ -294,13 +295,13 @@ func FixAbility() (int, int, error) { if common.UsingSQLite { err := DB.Exec("DELETE FROM abilities").Error if err != nil { - common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) return 0, 0, err } } else { err := DB.Exec("TRUNCATE TABLE abilities").Error if err != nil { - common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) + logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) return 0, 0, err } } @@ -320,7 +321,7 @@ func FixAbility() (int, int, error) { // Delete all abilities of this channel err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error if err != nil { - common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) failCount += len(chunk) continue } @@ -328,7 +329,7 @@ func FixAbility() (int, int, error) { for _, channel := range chunk { err = channel.AddAbilities(nil) if err != nil { - common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) + logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) failCount++ } else { successCount++ diff --git a/model/channel.go b/model/channel.go index 6239f05c..c0d253fc 100644 --- a/model/channel.go +++ b/model/channel.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/types" "strings" "sync" @@ -209,7 +210,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { if channel.OtherInfo != "" { err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { - common.SysError("failed to unmarshal other info: " + err.Error()) + 
logger.SysError("failed to unmarshal other info: " + err.Error()) } } return otherInfo @@ -218,7 +219,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { otherInfoBytes, err := json.Marshal(otherInfo) if err != nil { - common.SysError("failed to marshal other info: " + err.Error()) + logger.SysError("failed to marshal other info: " + err.Error()) return } channel.OtherInfo = string(otherInfoBytes) @@ -488,7 +489,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { ResponseTime: int(responseTime), }).Error if err != nil { - common.SysError("failed to update response time: " + err.Error()) + logger.SysError("failed to update response time: " + err.Error()) } } @@ -498,7 +499,7 @@ func (channel *Channel) UpdateBalance(balance float64) { Balance: balance, }).Error if err != nil { - common.SysError("failed to update balance: " + err.Error()) + logger.SysError("failed to update balance: " + err.Error()) } } @@ -614,7 +615,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if shouldUpdateAbilities { err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) + logger.SysError("failed to update ability status: " + err.Error()) } } }() @@ -642,7 +643,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri } err = channel.Save() if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + logger.SysError("failed to update channel status: " + err.Error()) return false } } @@ -704,7 +705,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models * for _, channel := range channels { err = channel.UpdateAbilities(nil) if err != nil { - common.SysError("failed to update abilities: " + err.Error()) + logger.SysError("failed to update abilities: " + err.Error()) 
} } } @@ -728,7 +729,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - common.SysError("failed to update channel used quota: " + err.Error()) + logger.SysError("failed to update channel used quota: " + err.Error()) } } @@ -821,7 +822,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { if channel.Setting != nil && *channel.Setting != "" { err := common.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + logger.SysError("failed to unmarshal setting: " + err.Error()) channel.Setting = nil // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -832,7 +833,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { func (channel *Channel) SetSetting(setting dto.ChannelSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + logger.SysError("failed to marshal setting: " + err.Error()) return } channel.Setting = common.GetPointer[string](string(settingBytes)) @@ -843,7 +844,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { if channel.OtherSettings != "" { err := common.UnmarshalJsonStr(channel.OtherSettings, &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + logger.SysError("failed to unmarshal setting: " + err.Error()) channel.OtherSettings = "{}" // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -854,7 +855,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + logger.SysError("failed to marshal setting: " + err.Error()) return } 
channel.OtherSettings = string(settingBytes) @@ -865,7 +866,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} { if channel.ParamOverride != nil && *channel.ParamOverride != "" { err := common.Unmarshal([]byte(*channel.ParamOverride), &paramOverride) if err != nil { - common.SysError("failed to unmarshal param override: " + err.Error()) + logger.SysError("failed to unmarshal param override: " + err.Error()) } } return paramOverride diff --git a/model/channel_cache.go b/model/channel_cache.go index 86866e40..22216027 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -6,6 +6,7 @@ import ( "math/rand" "one-api/common" "one-api/constant" + "one-api/logger" "one-api/setting" "one-api/setting/ratio_setting" "sort" @@ -84,13 +85,13 @@ func InitChannelCache() { } channelsIDM = newChannelId2channel channelSyncLock.Unlock() - common.SysLog("channels synced from database") + logger.SysLog("channels synced from database") } func SyncChannelCache(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing channels from database") + logger.SysLog("syncing channels from database") InitChannelCache() } } diff --git a/model/log.go b/model/log.go index 2070cd6f..d9495968 100644 --- a/model/log.go +++ b/model/log.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "one-api/common" + "one-api/logger" "os" "strings" "time" @@ -87,13 +88,13 @@ func RecordLog(userId int, logType int, content string) { } err := LOG_DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + logger.SysError("failed to record log: " + err.Error()) } } func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int, isStream bool, group string, other map[string]interface{}) { - common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, 
tokenName, content)) + logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) username := c.GetString("username") otherStr := common.MapToJsonStr(other) // 判断是否需要记录 IP @@ -129,7 +130,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, } err := LOG_DB.Create(log).Error if err != nil { - common.LogError(c, "failed to record log: "+err.Error()) + logger.LogError(c, "failed to record log: "+err.Error()) } } @@ -142,7 +143,6 @@ type RecordConsumeLogParams struct { Quota int `json:"quota"` Content string `json:"content"` TokenId int `json:"token_id"` - UserQuota int `json:"user_quota"` UseTimeSeconds int `json:"use_time_seconds"` IsStream bool `json:"is_stream"` Group string `json:"group"` @@ -150,7 +150,7 @@ type RecordConsumeLogParams struct { } func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) { - common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) + logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) if !common.LogConsumeEnabled { return } @@ -189,7 +189,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) } err := LOG_DB.Create(log).Error if err != nil { - common.LogError(c, "failed to record log: "+err.Error()) + logger.LogError(c, "failed to record log: "+err.Error()) } if common.DataExportEnabled { gopool.Go(func() { diff --git a/model/main.go b/model/main.go index dbf27152..1e582e1a 100644 --- a/model/main.go +++ b/model/main.go @@ -5,6 +5,7 @@ import ( "log" "one-api/common" "one-api/constant" + "one-api/logger" "os" "strings" "sync" @@ -84,7 +85,7 @@ func createRootAccountIfNeed() error { var user User //if user.Status != common.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - common.SysLog("no user exists, create a 
root user for you: username is root, password is 123456") + logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return err @@ -108,7 +109,7 @@ func CheckSetup() { if setup == nil { // No setup record exists, check if we have a root user if RootUserExists() { - common.SysLog("system is not initialized, but root user exists") + logger.SysLog("system is not initialized, but root user exists") // Create setup record newSetup := Setup{ Version: common.Version, @@ -116,16 +117,16 @@ func CheckSetup() { } err := DB.Create(&newSetup).Error if err != nil { - common.SysLog("failed to create setup record: " + err.Error()) + logger.SysLog("failed to create setup record: " + err.Error()) } constant.Setup = true } else { - common.SysLog("system is not initialized and no root user exists") + logger.SysLog("system is not initialized and no root user exists") constant.Setup = false } } else { // Setup record exists, system is initialized - common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) + logger.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) constant.Setup = true } } @@ -138,7 +139,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { if dsn != "" { if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { // Use PostgreSQL - common.SysLog("using PostgreSQL as database") + logger.SysLog("using PostgreSQL as database") if !isLog { common.UsingPostgreSQL = true } else { @@ -152,7 +153,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } if strings.HasPrefix(dsn, "local") { - common.SysLog("SQL_DSN not set, using SQLite as database") + logger.SysLog("SQL_DSN not set, using SQLite as database") if !isLog { common.UsingSQLite = true } else { @@ -163,7 +164,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { 
}) } // Use MySQL - common.SysLog("using MySQL as database") + logger.SysLog("using MySQL as database") // check parseTime if !strings.Contains(dsn, "parseTime") { if strings.Contains(dsn, "?") { @@ -182,7 +183,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } // Use SQLite - common.SysLog("SQL_DSN not set, using SQLite as database") + logger.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ PrepareStmt: true, // precompile SQL @@ -216,11 +217,11 @@ func InitDB() (err error) { if common.UsingMySQL { //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded } - common.SysLog("database migration started") + logger.SysLog("database migration started") err = migrateDB() return err } else { - common.FatalLog(err) + logger.FatalLog(err) } return err } @@ -253,11 +254,11 @@ func InitLogDB() (err error) { if !common.IsMasterNode { return nil } - common.SysLog("database migration started") + logger.SysLog("database migration started") err = migrateLOGDB() return err } else { - common.FatalLog(err) + logger.FatalLog(err) } return err } @@ -354,7 +355,7 @@ func migrateDBFast() error { return err } } - common.SysLog("database migrated") + logger.SysLog("database migrated") return nil } @@ -503,6 +504,6 @@ func PingDB() error { } lastPingTime = time.Now() - common.SysLog("Database pinged successfully") + logger.SysLog("Database pinged successfully") return nil } diff --git a/model/option.go b/model/option.go index 5c84d166..8fcd13a8 100644 --- a/model/option.go +++ b/model/option.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/logger" "one-api/setting" "one-api/setting/config" "one-api/setting/operation_setting" @@ -150,7 +151,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - 
common.SysError("failed to update option map: " + err.Error()) + logger.SysError("failed to update option map: " + err.Error()) } } } @@ -158,7 +159,7 @@ func loadOptionsFromDatabase() { func SyncOptions(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing options from database") + logger.SysLog("syncing options from database") loadOptionsFromDatabase() } } diff --git a/model/pricing.go b/model/pricing.go index 0936d298..31aa5cdf 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -3,6 +3,7 @@ package model import ( "encoding/json" "fmt" + "one-api/logger" "strings" "one-api/common" @@ -92,7 +93,7 @@ func updatePricing() { //modelRatios := common.GetModelRatios() enableAbilities, err := GetAllEnableAbilityWithChannels() if err != nil { - common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) + logger.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) return } // 预加载模型元数据与供应商一次,避免循环查询 diff --git a/model/redemption.go b/model/redemption.go index bf237668..1ab84f45 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "strconv" "gorm.io/gorm" @@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return 0, errors.New("兑换失败," + err.Error()) } - RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id)) + RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id)) return redemption.Quota, nil } diff --git a/model/token.go b/model/token.go index e85a445e..63c17e2d 100644 --- a/model/token.go +++ b/model/token.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "strings" "github.com/bytedance/gopkg/util/gopool" @@ -91,7 +92,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = 
common.TokenStatusExpired err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } return token, errors.New("该令牌已过期") @@ -102,7 +103,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExhausted err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } keyPrefix := key[:3] @@ -134,7 +135,7 @@ func GetTokenById(id int) (*Token, error) { if shouldUpdateRedis(true, err) { gopool.Go(func() { if err := cacheSetToken(token); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + logger.SysError("failed to update user status cache: " + err.Error()) } }) } @@ -147,7 +148,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) { if shouldUpdateRedis(fromDB, err) && token != nil { gopool.Go(func() { if err := cacheSetToken(*token); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + logger.SysError("failed to update user status cache: " + err.Error()) } }) } @@ -178,7 +179,7 @@ func (token *Token) Update() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - common.SysError("failed to update token cache: " + err.Error()) + logger.SysError("failed to update token cache: " + err.Error()) } }) } @@ -194,7 +195,7 @@ func (token *Token) SelectUpdate() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - common.SysError("failed to update token cache: " + err.Error()) + logger.SysError("failed to update token cache: " + err.Error()) } }) } @@ -209,7 +210,7 @@ func (token *Token) Delete() (err error) { gopool.Go(func() { err := cacheDeleteToken(token.Key) if err != nil { - common.SysError("failed to delete token cache: " + err.Error()) + logger.SysError("failed 
to delete token cache: " + err.Error()) } }) } @@ -269,7 +270,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheIncrTokenQuota(key, int64(quota)) if err != nil { - common.SysError("failed to increase token quota: " + err.Error()) + logger.SysError("failed to increase token quota: " + err.Error()) } }) } @@ -299,7 +300,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheDecrTokenQuota(key, int64(quota)) if err != nil { - common.SysError("failed to decrease token quota: " + err.Error()) + logger.SysError("failed to decrease token quota: " + err.Error()) } }) } diff --git a/model/topup.go b/model/topup.go index c34c0ce6..802c866f 100644 --- a/model/topup.go +++ b/model/topup.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "gorm.io/gorm" ) @@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) { return errors.New("充值失败," + err.Error()) } - RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount)) + RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount)) return nil } diff --git a/model/twofa.go b/model/twofa.go index d09ff9fe..b2ea54e0 100644 --- a/model/twofa.go +++ b/model/twofa.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "time" "gorm.io/gorm" @@ -243,7 +244,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { if !common.ValidateTOTPCode(t.Secret, code) { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - common.SysError("更新2FA失败次数失败: " + err.Error()) + logger.SysError("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -255,7 +256,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - common.SysError("更新2FA使用记录失败: " 
+ err.Error()) + logger.SysError("更新2FA使用记录失败: " + err.Error()) } return true, nil @@ -277,7 +278,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { if !valid { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - common.SysError("更新2FA失败次数失败: " + err.Error()) + logger.SysError("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -289,7 +290,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - common.SysError("更新2FA使用记录失败: " + err.Error()) + logger.SysError("更新2FA使用记录失败: " + err.Error()) } return true, nil diff --git a/model/usedata.go b/model/usedata.go index 1255b0be..f0027a8d 100644 --- a/model/usedata.go +++ b/model/usedata.go @@ -4,6 +4,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/logger" "sync" "time" ) @@ -24,12 +25,12 @@ func UpdateQuotaData() { // recover defer func() { if r := recover(); r != nil { - common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) + logger.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) } }() for { if common.DataExportEnabled { - common.SysLog("正在更新数据看板数据...") + logger.SysLog("正在更新数据看板数据...") SaveQuotaDataCache() } time.Sleep(time.Duration(common.DataExportInterval) * time.Minute) @@ -91,7 +92,7 @@ func SaveQuotaDataCache() { } } CacheQuotaData = make(map[string]*QuotaData) - common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size)) + logger.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size)) } func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) { @@ -102,7 +103,7 @@ func increaseQuotaData(userId int, username string, modelName string, count int, "token_used": gorm.Expr("token_used + ?", tokenUsed), }).Error if err != nil { - common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err)) + logger.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err)) } } diff --git a/model/user.go b/model/user.go index 
6021f495..244380ad 100644 --- a/model/user.go +++ b/model/user.go @@ -6,6 +6,7 @@ import ( "fmt" "one-api/common" "one-api/dto" + "one-api/logger" "strconv" "strings" @@ -75,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting { if user.Setting != "" { err := json.Unmarshal([]byte(user.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + logger.SysError("failed to unmarshal setting: " + err.Error()) } } return setting @@ -84,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting { func (user *User) SetSetting(setting dto.UserSetting) { settingBytes, err := json.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + logger.SysError("failed to marshal setting: " + err.Error()) return } user.Setting = string(settingBytes) @@ -274,7 +275,7 @@ func inviteUser(inviterId int) (err error) { func (user *User) TransferAffQuotaToQuota(quota int) error { // 检查quota是否小于最小额度 if float64(quota) < common.QuotaPerUnit { - return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit))) + return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit))) } // 开始数据库事务 @@ -324,16 +325,16 @@ func (user *User) Insert(inviterId int) error { return result.Error } if common.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) } if inviterId != 0 { if common.QuotaForInvitee > 0 { _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee))) } if common.QuotaForInviter > 0 { //_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", 
common.LogQuota(common.QuotaForInviter))) + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter))) _ = inviteUser(inviterId) } } @@ -517,7 +518,7 @@ func IsAdmin(userId int) bool { var user User err := DB.Where("id = ?", userId).Select("role").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) + logger.SysError("no such user " + err.Error()) return false } return user.Role >= common.RoleAdminUser @@ -572,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserQuotaCache(id, quota); err != nil { - common.SysError("failed to update user quota cache: " + err.Error()) + logger.SysError("failed to update user quota cache: " + err.Error()) } }) } @@ -610,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserGroupCache(id, group); err != nil { - common.SysError("failed to update user group cache: " + err.Error()) + logger.SysError("failed to update user group cache: " + err.Error()) } }) } @@ -639,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserSettingCache(id, setting); err != nil { - common.SysError("failed to update user setting cache: " + err.Error()) + logger.SysError("failed to update user setting cache: " + err.Error()) } }) } @@ -669,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) { gopool.Go(func() { err := cacheIncrUserQuota(id, int64(quota)) if err != nil { - common.SysError("failed to increase user quota: " + err.Error()) + logger.SysError("failed to increase user quota: " + err.Error()) } }) if !db && common.BatchUpdateEnabled { @@ -694,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { gopool.Go(func() { err := 
cacheDecrUserQuota(id, int64(quota)) if err != nil { - common.SysError("failed to decrease user quota: " + err.Error()) + logger.SysError("failed to decrease user quota: " + err.Error()) } }) if common.BatchUpdateEnabled { @@ -750,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota and request count: " + err.Error()) + logger.SysError("failed to update user used quota and request count: " + err.Error()) return } @@ -767,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota: " + err.Error()) + logger.SysError("failed to update user used quota: " + err.Error()) } } func updateUserRequestCount(id int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error if err != nil { - common.SysError("failed to update user request count: " + err.Error()) + logger.SysError("failed to update user request count: " + err.Error()) } } @@ -785,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserNameCache(id, username); err != nil { - common.SysError("failed to update user name cache: " + err.Error()) + logger.SysError("failed to update user name cache: " + err.Error()) } }) } diff --git a/model/user_cache.go b/model/user_cache.go index a631457c..dec7597b 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -5,6 +5,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "time" "github.com/gin-gonic/gin" @@ -37,7 +38,7 @@ func (user *UserBase) GetSetting() dto.UserSetting { if user.Setting != "" { err := common.Unmarshal([]byte(user.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + logger.SysError("failed to unmarshal setting: " + 
err.Error()) } } return setting @@ -78,7 +79,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) { if shouldUpdateRedis(fromDB, err) && user != nil { gopool.Go(func() { if err := updateUserCache(*user); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + logger.SysError("failed to update user status cache: " + err.Error()) } }) } diff --git a/model/utils.go b/model/utils.go index 1f8a0963..abd96b79 100644 --- a/model/utils.go +++ b/model/utils.go @@ -3,6 +3,7 @@ package model import ( "errors" "one-api/common" + "one-api/logger" "sync" "time" @@ -65,7 +66,7 @@ func batchUpdate() { return } - common.SysLog("batch update started") + logger.SysLog("batch update started") for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] @@ -77,12 +78,12 @@ func batchUpdate() { case BatchUpdateTypeUserQuota: err := increaseUserQuota(key, value) if err != nil { - common.SysError("failed to batch update user quota: " + err.Error()) + logger.SysError("failed to batch update user quota: " + err.Error()) } case BatchUpdateTypeTokenQuota: err := increaseTokenQuota(key, value) if err != nil { - common.SysError("failed to batch update token quota: " + err.Error()) + logger.SysError("failed to batch update token quota: " + err.Error()) } case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) @@ -93,7 +94,7 @@ func batchUpdate() { } } } - common.SysLog("batch update finished") + logger.SysLog("batch update finished") } func RecordExist(err error) (bool, error) { diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 88777838..1bc2d90b 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -4,107 +4,40 @@ import ( "errors" "fmt" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/types" - "strings" "github.com/gin-gonic/gin" ) 
-func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { - audioRequest := &dto.AudioRequest{} - err := common.UnmarshalBodyReusable(c, audioRequest) - if err != nil { - return nil, err - } - switch info.RelayMode { - case relayconstant.RelayModeAudioSpeech: - if audioRequest.Model == "" { - return nil, errors.New("model is required") - } - if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveInput(audioRequest.Input) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) - return nil, err - } - } - default: - err = c.Request.ParseForm() - if err != nil { - return nil, err - } - formData := c.Request.PostForm - if audioRequest.Model == "" { - audioRequest.Model = formData.Get("model") - } +func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - if audioRequest.Model == "" { - return nil, errors.New("model is required") - } - audioRequest.ResponseFormat = formData.Get("response_format") - if audioRequest.ResponseFormat == "" { - audioRequest.ResponseFormat = "json" - } - } - return audioRequest, nil -} - -func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c) - audioRequest, err := getAndValidAudioRequest(c, relayInfo) - - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + audioRequest, ok := info.Request.(*dto.AudioRequest) + if !ok { + return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - promptTokens := 0 - preConsumedTokens := common.PreConsumedQuota - if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { - promptTokens = service.CountTTSToken(audioRequest.Input, 
audioRequest.Model) - preConsumedTokens = promptTokens - relayInfo.PromptTokens = promptTokens - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr - } - defer func() { - if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - err = helper.ModelMappedHelper(c, relayInfo, audioRequest) + err := helper.ModelMappedHelper(c, info, audioRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) - ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) + ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - resp, err := adaptor.DoRequest(c, relayInfo, ioReader) + resp, err := adaptor.DoRequest(c, info, ioReader) if err != nil { return types.NewError(err, types.ErrorCodeDoRequestFailed) } @@ -121,14 +54,14 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, 
statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 754f29c8..841896cf 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -6,8 +6,8 @@ import ( "fmt" "io" "net/http" - "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/service" "one-api/types" @@ -43,7 +43,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error client := &http.Client{} resp, err := client.Do(req) if err != nil { - common.SysError("updateTask client.Do err: " + err.Error()) + logger.SysError("updateTask client.Do err: " + err.Error()) return &aliResponse, err, nil } defer resp.Body.Close() @@ -53,7 +53,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error var response AliResponse err = json.Unmarshal(responseBody, &response) if err != nil { - common.SysError("updateTask NewDecoder err: " + err.Error()) + logger.SysError("updateTask NewDecoder err: " + err.Error()) return &aliResponse, err, nil } @@ -109,7 +109,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc if responseFormat == "b64_json" { _, b64, err := service.GetImageFromUrl(data.Url) if err != nil { - common.LogError(c, "get_image_data_failed: "+err.Error()) + logger.LogError(c, "get_image_data_failed: "+err.Error()) continue } b64Json = b64 @@ -134,14 +134,14 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliTaskResponse) if err != nil { return types.NewOpenAIError(err, 
types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliTaskResponse.Message != "" { - common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) + logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil } diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go index 4f448e01..e7d6b514 100644 --- a/relay/channel/ali/rerank.go +++ b/relay/channel/ali/rerank.go @@ -4,9 +4,9 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -36,7 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var aliResponse AliRerankResponse err = json.Unmarshal(responseBody, &aliResponse) diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index fcf63854..17fcef2a 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -7,7 +7,9 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/helper" + "one-api/service" "strings" "one-api/types" @@ -46,7 +48,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) model := c.GetString("model") if model == "" { @@ -148,7 +150,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, var aliResponse AliResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { - common.SysError("error unmarshalling stream 
response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if aliResponse.Usage.OutputTokens != 0 { @@ -161,7 +163,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -171,7 +173,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, &usage } @@ -181,7 +183,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 3ccd2d78..fd745cf7 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -7,6 +7,7 @@ import ( "io" "net/http" common2 "one-api/common" + "one-api/logger" "one-api/relay/common" "one-api/relay/constant" "one-api/relay/helper" @@ -181,7 +182,7 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error { err := helper.PingData(c) if err != nil { - common2.LogError(c, "SSE ping error: "+err.Error()) + logger.LogError(c, "SSE ping error: "+err.Error()) done <- err return } diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index a7cd5996..696c2496 100644 --- 
a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -118,7 +119,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var baiduResponse BaiduChatStreamResponse err := common.Unmarshal([]byte(data), &baiduResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if baiduResponse.Usage.TotalTokens != 0 { @@ -129,11 +130,11 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseBaidu2OpenAI(&baiduResponse) err = helper.ObjectData(c, response) if err != nil { - common.SysError("error sending stream response: " + err.Error()) + logger.SysError("error sending stream response: " + err.Error()) } return true }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, usage } @@ -143,7 +144,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil @@ -168,7 +169,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil diff --git a/relay/channel/claude/relay-claude.go 
b/relay/channel/claude/relay-claude.go index e4d3975e..5d839908 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -375,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla for _, toolCall := range message.ParseToolCalls() { inputObj := make(map[string]any) if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { - common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) + logger.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) continue } claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ @@ -609,7 +610,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud var claudeResponse dto.ClaudeResponse err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return types.NewError(err, types.ErrorCodeBadResponseBody) } if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" { @@ -637,7 +638,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud err = helper.ObjectData(c, response) if err != nil { - common.LogError(c, "send_stream_response_failed: "+err.Error()) + logger.LogError(c, "send_stream_response_failed: "+err.Error()) } } return nil @@ -653,7 +654,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau } if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done { if common.DebugEnabled { - common.SysError("claude response usage is not 
complete, maybe upstream error") + logger.SysError("claude response usage is not complete, maybe upstream error") } claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } @@ -667,7 +668,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { - common.SysError("send final response failed: " + err.Error()) + logger.SysError("send final response failed: " + err.Error()) } } helper.Done(c) @@ -736,12 +737,12 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests) } - common.IOCopyBytesGracefully(c, nil, responseData) + service.IOCopyBytesGracefully(c, nil, responseData) return nil } func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) claudeInfo := &ClaudeResponseInfo{ ResponseId: helper.GetResponseID(c), diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index 5e8fe7f9..00f6b6c5 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -5,8 +5,8 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -51,7 +51,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res var response dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &response) if err != nil { - common.LogError(c, 
"error_unmarshalling_stream_response: "+err.Error()) + logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) continue } for _, choice := range response.Choices { @@ -66,24 +66,24 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res info.FirstResponseTime = time.Now() } if err != nil { - common.LogError(c, "error_rendering_stream_response: "+err.Error()) + logger.LogError(c, "error_rendering_stream_response: "+err.Error()) } } if err := scanner.Err(); err != nil { - common.LogError(c, "error_scanning_stream_response: "+err.Error()) + logger.LogError(c, "error_scanning_stream_response: "+err.Error()) } usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) err := helper.ObjectData(c, response) if err != nil { - common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) + logger.LogError(c, "error_rendering_final_usage_response: "+err.Error()) } } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, usage } @@ -93,7 +93,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var response dto.TextResponse err = json.Unmarshal(responseBody, &response) if err != nil { @@ -123,7 +123,7 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &cfResp) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil diff --git 
a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index fcfb12b7..ccef9b23 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -118,7 +119,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var cohereResp CohereResponse err := json.Unmarshal([]byte(data), &cohereResp) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } var openaiResp dto.ChatCompletionsStreamResponse @@ -153,7 +154,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } jsonStr, err := json.Marshal(openaiResp) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) @@ -175,7 +176,7 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var cohereResp CohereResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { @@ -216,7 +217,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon. 
if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var cohereResp CohereRerankResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 32cc6937..18ed46af 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -9,6 +9,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -49,7 +50,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) // convert coze response to openai response var response dto.TextResponse var cozeResponse CozeChatDetailResponse @@ -154,7 +155,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var chatData CozeChatResponseData err := json.Unmarshal([]byte(data), &chatData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + logger.SysError("error_unmarshalling_stream_response: " + err.Error()) return } @@ -171,14 +172,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var messageData CozeChatV3MessageDetail err := json.Unmarshal([]byte(data), &messageData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + logger.SysError("error_unmarshalling_stream_response: " + err.Error()) return } var content string err = json.Unmarshal(messageData.Content, &content) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + logger.SysError("error_unmarshalling_stream_response: " + err.Error()) return } @@ -203,11 +204,11 @@ func 
handleCozeEvent(c *gin.Context, event string, data string, responseText *st var errorData CozeError err := json.Unmarshal([]byte(data), &errorData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + logger.SysError("error_unmarshalling_stream_response: " + err.Error()) return } - common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + logger.SysError(fmt.Sprintf("stream event error: code=%v, message=%v", errorData.Code, errorData.Message)) } } diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 47337127..f03d61a4 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -36,14 +37,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Decode base64 string decodedData, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { - common.SysError("failed to decode base64: " + err.Error()) + logger.SysError("failed to decode base64: " + err.Error()) return nil } // Create temporary file tempFile, err := os.CreateTemp("", "dify-upload-*") if err != nil { - common.SysError("failed to create temp file: " + err.Error()) + logger.SysError("failed to create temp file: " + err.Error()) return nil } defer tempFile.Close() @@ -51,7 +52,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Write decoded data to temp file if _, err := tempFile.Write(decodedData); err != nil { - common.SysError("failed to write to temp file: " + err.Error()) + logger.SysError("failed to write to temp file: " + err.Error()) return nil } @@ -61,7 +62,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Add user field if err := writer.WriteField("user", user); err != nil { - common.SysError("failed to 
add user field: " + err.Error()) + logger.SysError("failed to add user field: " + err.Error()) return nil } @@ -74,13 +75,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create form file part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/"))) if err != nil { - common.SysError("failed to create form file: " + err.Error()) + logger.SysError("failed to create form file: " + err.Error()) return nil } // Copy file content to form if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil { - common.SysError("failed to copy file content: " + err.Error()) + logger.SysError("failed to copy file content: " + err.Error()) return nil } writer.Close() @@ -88,7 +89,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create HTTP request req, err := http.NewRequest("POST", uploadUrl, body) if err != nil { - common.SysError("failed to create request: " + err.Error()) + logger.SysError("failed to create request: " + err.Error()) return nil } @@ -99,7 +100,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me client := service.GetHttpClient() resp, err := client.Do(req) if err != nil { - common.SysError("failed to send request: " + err.Error()) + logger.SysError("failed to send request: " + err.Error()) return nil } defer resp.Body.Close() @@ -109,7 +110,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me Id string `json:"id"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - common.SysError("failed to decode response: " + err.Error()) + logger.SysError("failed to decode response: " + err.Error()) return nil } @@ -219,7 +220,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R var difyResponse DifyChunkChatCompletionResponse err := json.Unmarshal([]byte(data), &difyResponse) if err != nil { - common.SysError("error 
unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } var openaiResponse dto.ChatCompletionsStreamResponse @@ -239,7 +240,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } err = helper.ObjectData(c, openaiResponse) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) } return true }) @@ -258,7 +259,7 @@ func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &difyResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 4141caf7..05d974f6 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf }, }, Parameters: dto.GeminiImageParameters{ - SampleCount: request.N, + SampleCount: int(request.N), AspectRatio: aspectRatio, PersonGeneration: "allow_adult", // default allow adult }, diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 7f2f51fb..974a22f5 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -5,6 +5,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -17,7 +18,7 @@ import ( ) func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // 读取响应体 responseBody, err := 
io.ReadAll(resp.Body) @@ -53,13 +54,13 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re } } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &usage, nil } func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -89,7 +90,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel } } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return usage, nil } @@ -106,7 +107,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn var geminiResponse dto.GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { - common.LogError(c, "error unmarshalling stream response: "+err.Error()) + logger.LogError(c, "error unmarshalling stream response: "+err.Error()) return false } @@ -140,7 +141,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn // 直接发送 GeminiChatResponse 响应 err = helper.StringData(c, data) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } info.SendResponseCount++ return true diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 58efa1a5..82a2d8de 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -901,7 +902,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * var geminiResponse dto.GeminiChatResponse err := 
common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { - common.LogError(c, "error unmarshalling stream response: "+err.Error()) + logger.LogError(c, "error unmarshalling stream response: "+err.Error()) return false } @@ -945,7 +946,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * finishReason = constant.FinishReasonToolCalls err = handleStream(c, info, emptyResponse) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } response.ClearToolCalls() @@ -957,7 +958,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * err = handleStream(c, info, response) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } if isStop { _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)) @@ -993,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) err := handleFinalStream(c, info, response) if err != nil { - common.SysError("send final response failed: " + err.Error()) + logger.SysError("send final response failed: " + err.Error()) } //if info.RelayFormat == relaycommon.RelayFormatOpenAI { // helper.Done(c) @@ -1007,7 +1008,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println(string(responseBody)) } @@ -1057,13 +1058,13 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R break } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &usage, nil } func GeminiEmbeddingHandler(c *gin.Context, info 
*relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { @@ -1107,7 +1108,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go index 28af1866..11a0117b 100644 --- a/relay/channel/jimeng/image.go +++ b/relay/channel/jimeng/image.go @@ -5,9 +5,9 @@ import ( "fmt" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -54,7 +54,7 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &jimengResponse) if err != nil { diff --git a/relay/channel/jimeng/sign.go b/relay/channel/jimeng/sign.go index c9db6630..d8b598dc 100644 --- a/relay/channel/jimeng/sign.go +++ b/relay/channel/jimeng/sign.go @@ -12,7 +12,7 @@ import ( "io" "net/http" "net/url" - "one-api/common" + "one-api/logger" "sort" "strings" "time" @@ -44,7 +44,7 @@ func SetPayloadHash(c *gin.Context, req any) error { if err != nil { return err } - common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) + logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) payloadHash := sha256.Sum256(body) hexPayloadHash := hex.EncodeToString(payloadHash[:]) c.Set(HexPayloadHashKey, hexPayloadHash) diff --git 
a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go index 78f96d6d..d91aceb3 100644 --- a/relay/channel/mokaai/relay-mokaai.go +++ b/relay/channel/mokaai/relay-mokaai.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -56,7 +57,7 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) @@ -77,6 +78,6 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &fullTextResponse.Usage, nil } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index d4686ce3..066581fa 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -94,7 +94,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) @@ -123,7 +123,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, 
http.StatusInternalServerError) } - common.IOCopyBytesGracefully(c, resp, doResponseBody) + service.IOCopyBytesGracefully(c, resp, doResponseBody) return usage, nil } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 696c5cb0..80973aa1 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" @@ -50,7 +51,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { - common.LogError(c, "failed to unmarshal stream response: "+err.Error()) + logger.LogError(c, "failed to unmarshal stream response: "+err.Error()) return err } @@ -63,7 +64,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { - common.LogError(c, "failed to marshal gemini response: "+err.Error()) + logger.LogError(c, "failed to marshal gemini response: "+err.Error()) return err } @@ -110,14 +111,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex var streamResponses []dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { return err } if err := 
ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { - common.SysError("error processing stream response: " + err.Error()) + logger.SysError("error processing stream response: " + err.Error()) } } return nil @@ -146,7 +147,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui var streamResponses []dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { @@ -213,7 +214,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream info.ClaudeConvertInfo.Done = true var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return } @@ -227,7 +228,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream case relaycommon.RelayFormatGemini: var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return } @@ -245,7 +246,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { - common.SysError("error marshalling gemini response: " + err.Error()) + logger.SysError("error marshalling gemini 
response: " + err.Error()) return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b8e72273..447e0f31 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -108,11 +109,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") + logger.LogError(c, "invalid response or response body") return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) } - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) model := info.UpstreamModelName var responseId string @@ -129,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re if lastStreamData != "" { err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) if err != nil { - common.SysError("error handling stream format: " + err.Error()) + logger.SysError("error handling stream format: " + err.Error()) } } if len(data) > 0 { @@ -143,7 +144,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re shouldSendLastResp := true if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage, &containStreamUsage, info, &shouldSendLastResp); err != nil { - common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData)) + logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: 
[%s]", err.Error(), lastStreamData)) } if info.RelayFormat == relaycommon.RelayFormatOpenAI { @@ -154,7 +155,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re // 处理token计算 if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { - common.LogError(c, "error processing tokens: "+err.Error()) + logger.LogError(c, "error processing tokens: "+err.Error()) } if !containStreamUsage { @@ -173,7 +174,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) @@ -235,7 +236,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo responseBody = geminiRespStr } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &simpleResponse.Usage, nil } @@ -247,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel // if the upstream returns a specific status code, once the upstream has already written the header, // the subsequent failure of the response body should be regarded as a non-recoverable error, // and can be terminated directly. 
- defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens usage.TotalTokens = info.PromptTokens @@ -258,13 +259,13 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel c.Writer.WriteHeaderNow() _, err := io.Copy(c.Writer, resp.Body) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } return usage } func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // count tokens by audio file duration audioTokens, err := countAudioTokens(c) @@ -276,7 +277,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) usage := &dto.Usage{} usage.PromptTokens = audioTokens @@ -386,7 +387,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.InputTokens += textToken + audioToken localUsage.InputTokenDetails.TextTokens += textToken @@ -459,7 +460,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. 
errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken info.IsFirstRequest = false localUsage.InputTokens += textToken + audioToken @@ -474,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. localUsage = &dto.RealtimeUsage{} // print now usage } - common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session @@ -491,7 +492,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.OutputTokens += textToken + audioToken localUsage.OutputTokenDetails.TextTokens += textToken @@ -517,7 +518,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. 
case <-targetClosed: case err := <-errChan: //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil - common.LogError(c, "realtime error: "+err.Error()) + logger.LogError(c, "realtime error: "+err.Error()) case <-c.Done(): } @@ -553,7 +554,7 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R } func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -567,7 +568,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) // Once we've written to the client, we should not return errors anymore // because the upstream has already consumed resources and returned content diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index bae6fcb6..754a6f44 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -16,7 +17,7 @@ import ( ) func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // read response body var responsesResponse dto.OpenAIResponsesResponse @@ -33,7 +34,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) // compute usage usage 
:= dto.Usage{} @@ -54,7 +55,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") + logger.LogError(c, "invalid response or response body") return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) } diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 9b8bce7d..1264b2b4 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -58,15 +59,15 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, go func() { responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError("error reading stream response: " + err.Error()) + logger.SysError("error reading stream response: " + err.Error()) stopChan <- true return } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) stopChan <- true return } @@ -78,7 +79,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, } jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) stopChan <- true return } @@ -96,7 +97,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, return false } }) - 
common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, responseText } @@ -105,7 +106,7 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { @@ -133,6 +134,6 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &usage, nil } diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go index 2e37ad15..b21faccb 100644 --- a/relay/channel/siliconflow/relay-siliconflow.go +++ b/relay/channel/siliconflow/relay-siliconflow.go @@ -4,9 +4,9 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -17,7 +17,7 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var siliconflowResp SFRerankResponse err = json.Unmarshal(responseBody, &siliconflowResp) if err != nil { @@ -39,6 +39,6 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, 
jsonResponse) return usage, nil } diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 9c04c7ad..1deb33fd 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" @@ -139,7 +140,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody)) if err != nil { - common.SysError(fmt.Sprintf("Get Task error: %v", err)) + logger.SysError(fmt.Sprintf("Get Task error: %v", err)) return nil, err } defer req.Body.Close() diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index 78ce6238..d3aeab3f 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -13,6 +13,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -106,7 +107,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt var tencentResponse TencentChatResponse err := json.Unmarshal([]byte(data), &tencentResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) continue } @@ -117,17 +118,17 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt err = helper.ObjectData(c, response) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) } } if err := scanner.Err(); err != nil { - common.SysError("error reading stream: " + err.Error()) + logger.SysError("error reading stream: " + err.Error()) } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return 
service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil } @@ -138,7 +139,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &tencentSb) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) @@ -156,7 +157,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &fullTextResponse.Usage, nil } diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 4d098102..4d4e7b92 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -47,7 +48,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var xAIResp *dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &xAIResp) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } @@ -63,7 +64,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount) err = helper.ObjectData(c, openaiResponse) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) } return true }) @@ -74,12 +75,12 @@ func xAIStreamHandler(c 
*gin.Context, info *relaycommon.RelayInfo, resp *http.Re } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return usage, nil } func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -101,7 +102,7 @@ func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.IOCopyBytesGracefully(c, resp, encodeJson) + service.IOCopyBytesGracefully(c, resp, encodeJson) return xaiResponse.Usage, nil } diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 1a426d50..398bb08d 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/relay/helper" "one-api/types" "strings" @@ -143,7 +144,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -218,20 +219,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap for { _, msg, err := conn.ReadMessage() if err != nil { - common.SysError("error reading stream response: " + err.Error()) + logger.SysError("error reading stream response: " + err.Error()) break } var response XunfeiChatResponse err = json.Unmarshal(msg, &response) if err != nil { - common.SysError("error unmarshalling stream response: 
" + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() if err != nil { - common.SysError("error closing websocket connection: " + err.Error()) + logger.SysError("error closing websocket connection: " + err.Error()) } break } @@ -282,6 +283,6 @@ func getAPIVersion(c *gin.Context, modelName string) string { return apiVersion } apiVersion = "v1.1" - common.SysLog("api_version not found, using default: " + apiVersion) + logger.SysLog("api_version not found, using default: " + apiVersion) return apiVersion } diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 35882ed5..65b662b6 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -8,8 +8,10 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" + "one-api/service" "one-api/types" "strings" "sync" @@ -38,7 +40,7 @@ func getZhipuToken(apikey string) string { split := strings.Split(apikey, ".") if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) + logger.SysError("invalid zhipu key: " + apikey) return "" } @@ -186,7 +188,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseZhipu2OpenAI(data) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -195,13 +197,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. 
var zhipuResponse ZhipuStreamMetaResponse err := json.Unmarshal([]byte(data), &zhipuResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } usage = zhipuUsage @@ -212,7 +214,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return usage, nil } @@ -222,7 +224,7 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) diff --git a/relay/relay-mj.go b/relay/chat_handler.go similarity index 98% rename from relay/relay-mj.go rename to relay/chat_handler.go index e7f316b9..30bce55c 100644 --- a/relay/relay-mj.go +++ b/relay/chat_handler.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -214,7 +215,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + 
logger.SysError("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") @@ -300,7 +301,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") } - common.IOCopyBytesGracefully(c, nil, respBody) + service.IOCopyBytesGracefully(c, nil, respBody) return nil } @@ -521,7 +522,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + logger.SysError("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) @@ -572,7 +573,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //无实例账号自动禁用渠道(No available account instance) channel, err := model.GetChannelById(midjourneyTask.ChannelId, true) if err != nil { - common.SysError("get_channel_null: " + err.Error()) + logger.SysError("get_channel_null: " + err.Error()) } if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") diff --git a/relay/claude_handler.go b/relay/claude_handler.go index b4bf78ff..ddc424b4 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "errors" "fmt" "io" "net/http" @@ -18,68 +17,26 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { - textRequest = &dto.ClaudeRequest{} - err = c.ShouldBindJSON(textRequest) - if 
err != nil { - return nil, err - } - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return nil, errors.New("field messages is required") - } - if textRequest.Model == "" { - return nil, errors.New("field model is required") - } - return textRequest, nil -} +func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { -func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - relayInfo := relaycommon.GenRelayInfoClaude(c) + textRequest, ok := info.Request.(*dto.ClaudeRequest) - // get & validate textRequest 获取并验证文本请求 - textRequest, err := getAndValidateClaudeRequest(c) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) } - if textRequest.Stream { - relayInfo.IsStream = true - } - - err = helper.ModelMappedHelper(c, relayInfo, textRequest) + err := helper.ModelMappedHelper(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) - // count messages token error 计算promptTokens错误 - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := 
GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) if textRequest.MaxTokens == 0 { textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model)) @@ -104,18 +61,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { textRequest.Temperature = common.GetPointer[float64](1.0) } textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") - relayInfo.UpstreamModelName = textRequest.Model + info.UpstreamModelName = textRequest.Model } var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) + convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -125,10 +82,10 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range 
info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -145,14 +102,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { statusCodeMappingStr := c.GetString("status_code_mapping") var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -161,24 +118,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) //log.Printf("usage: %v", usage) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + + service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage)) return nil } - -func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - default: - promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName) - } - info.PromptTokens = promptTokens - return promptTokens, err -} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 5cd9223b..59be0011 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -1,10 +1,12 
@@ package common import ( + "errors" "one-api/common" "one-api/constant" "one-api/dto" relayconstant "one-api/relay/constant" + "one-api/types" "strings" "time" @@ -33,17 +35,6 @@ type ClaudeConvertInfo struct { Done bool } -const ( - RelayFormatOpenAI = "openai" - RelayFormatClaude = "claude" - RelayFormatGemini = "gemini" - RelayFormatOpenAIResponses = "openai_responses" - RelayFormatOpenAIAudio = "openai_audio" - RelayFormatOpenAIImage = "openai_image" - RelayFormatRerank = "rerank" - RelayFormatEmbedding = "embedding" -) - type RerankerInfo struct { Documents []any ReturnDocuments bool @@ -59,61 +50,103 @@ type ResponsesUsageInfo struct { BuiltInTools map[string]*BuildInToolInfo } -type RelayInfo struct { +type ChannelMeta struct { ChannelType int ChannelId int - ChannelIsMultiKey bool // 是否多密钥 - ChannelMultiKeyIndex int // 多密钥索引 - TokenId int - TokenKey string - UserId int - UsingGroup string // 使用的分组 - UserGroup string // 用户所在分组 - TokenUnlimited bool - StartTime time.Time - FirstResponseTime time.Time - isFirstResponse bool + ChannelIsMultiKey bool + ChannelMultiKeyIndex int + ChannelBaseUrl string + ApiType int + ApiVersion string + ApiKey string + Organization string + ChannelCreateTime int64 + ParamOverride map[string]interface{} + ChannelSetting dto.ChannelSettings + ChannelOtherSettings dto.ChannelOtherSettings + UpstreamModelName string + IsModelMapped bool +} + +type RelayInfo struct { + TokenId int + TokenKey string + UserId int + UsingGroup string // 使用的分组 + UserGroup string // 用户所在分组 + TokenUnlimited bool + StartTime time.Time + FirstResponseTime time.Time + isFirstResponse bool //SendLastReasoningResponse bool - ApiType int IsStream bool IsGeminiBatchEmbedding bool IsPlayground bool UsePrice bool RelayMode int - UpstreamModelName string OriginModelName string //RecodeModelName string - RequestURLPath string - ApiVersion string - PromptTokens int - ApiKey string - Organization string - BaseUrl string - SupportStreamOptions bool - ShouldIncludeUsage 
bool - DisablePing bool // 是否禁止向下游发送自定义 Ping - IsModelMapped bool - ClientWs *websocket.Conn - TargetWs *websocket.Conn - InputAudioFormat string - OutputAudioFormat string - RealtimeTools []dto.RealTimeTool - IsFirstRequest bool - AudioUsage bool - ReasoningEffort string - ChannelSetting dto.ChannelSettings - ChannelOtherSettings dto.ChannelOtherSettings - ParamOverride map[string]interface{} - UserSetting dto.UserSetting - UserEmail string - UserQuota int - RelayFormat string - SendResponseCount int - ChannelCreateTime int64 + RequestURLPath string + PromptTokens int + SupportStreamOptions bool + ShouldIncludeUsage bool + DisablePing bool // 是否禁止向下游发送自定义 Ping + ClientWs *websocket.Conn + TargetWs *websocket.Conn + InputAudioFormat string + OutputAudioFormat string + RealtimeTools []dto.RealTimeTool + IsFirstRequest bool + AudioUsage bool + ReasoningEffort string + UserSetting dto.UserSetting + UserEmail string + UserQuota int + RelayFormat types.RelayFormat + SendResponseCount int + FinalPreConsumedQuota int // 最终预消耗的配额 + + PriceData types.PriceData + + Request dto.Request + ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo *ResponsesUsageInfo + *ChannelMeta +} + +func (info *RelayInfo) InitChannelMeta(c *gin.Context) { + channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) + paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) + apiType, _ := common.ChannelType2APIType(channelType) + channelMeta := &ChannelMeta{ + ChannelType: channelType, + ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId), + ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), + ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex), + ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl), + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), 
+ Organization: c.GetString("channel_organization"), + ChannelCreateTime: c.GetInt64("channel_create_time"), + ParamOverride: paramOverride, + UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), + IsModelMapped: false, + } + + channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) + if ok { + channelMeta.ChannelSetting = channelSetting + } + + channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting) + if ok { + channelMeta.ChannelOtherSettings = channelOtherSettings + } + info.ChannelMeta = channelMeta } // 定义支持流式选项的通道类型 @@ -132,7 +165,8 @@ var streamSupportedChannels = map[int]bool{ } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { - info := GenRelayInfo(c) + info := genBaseRelayInfo(c, nil) + info.RelayFormat = types.RelayFormatOpenAIRealtime info.ClientWs = ws info.InputAudioFormat = "pcm16" info.OutputAudioFormat = "pcm16" @@ -140,9 +174,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { return info } -func GenRelayInfoClaude(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatClaude +func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatClaude info.ShouldIncludeUsage = false info.ClaudeConvertInfo = &ClaudeConvertInfo{ LastMessagesType: LastMessageTypeNone, @@ -150,41 +184,41 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { return info } -func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { - info := GenRelayInfo(c) +func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo { + info := genBaseRelayInfo(c, request) info.RelayMode = relayconstant.RelayModeRerank - info.RelayFormat = RelayFormatRerank + info.RelayFormat = types.RelayFormatRerank info.RerankerInfo = &RerankerInfo{ - Documents: 
req.Documents, - ReturnDocuments: req.GetReturnDocuments(), + Documents: request.Documents, + ReturnDocuments: request.GetReturnDocuments(), } return info } -func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatOpenAIAudio +func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAIAudio return info } -func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatEmbedding +func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatEmbedding return info } -func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { - info := GenRelayInfo(c) +func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo { + info := genBaseRelayInfo(c, request) info.RelayMode = relayconstant.RelayModeResponses - info.RelayFormat = RelayFormatOpenAIResponses + info.RelayFormat = types.RelayFormatOpenAIResponses info.SupportStreamOptions = false info.ResponsesUsageInfo = &ResponsesUsageInfo{ BuiltInTools: make(map[string]*BuildInToolInfo), } - if len(req.Tools) > 0 { - for _, tool := range req.Tools { + if len(request.Tools) > 0 { + for _, tool := range request.Tools { toolType := common.Interface2String(tool["type"]) info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{ ToolName: toolType, @@ -200,104 +234,76 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel } } } - info.IsStream = req.Stream return info } -func GenRelayInfoGemini(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatGemini +func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = 
types.RelayFormatGemini info.ShouldIncludeUsage = false + return info } -func GenRelayInfoImage(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatOpenAIImage +func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAIImage return info } -func GenRelayInfo(c *gin.Context) *RelayInfo { - channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) - channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) - paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) +func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAI + return info +} + +func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { + + //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) + //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) + //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) - tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) - tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey) - userId := common.GetContextKeyInt(c, constant.ContextKeyUserId) - tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited) startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime) if startTime.IsZero() { startTime = time.Now() } + // firstResponseTime = time.Now() - 1 second - apiType, _ := common.ChannelType2APIType(channelType) - info := &RelayInfo{ - UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), - UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), - isFirstResponse: true, - RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), - BaseUrl: common.GetContextKeyString(c, 
constant.ContextKeyChannelBaseUrl), - RequestURLPath: c.Request.URL.String(), - ChannelType: channelType, - ChannelId: channelId, - TokenId: tokenId, - TokenKey: tokenKey, - UserId: userId, - UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), - UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), - TokenUnlimited: tokenUnlimited, + Request: request, + + UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId), + UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), + UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), + UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), + UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), + + OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), + PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens), + + TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId), + TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey), + TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited), + + isFirstResponse: true, + RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), + RequestURLPath: c.Request.URL.String(), + IsStream: request.IsStream(c), + StartTime: startTime, FirstResponseTime: startTime.Add(-time.Second), - OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), - UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), - //RecodeModelName: c.GetString("original_model"), - IsModelMapped: false, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), - Organization: c.GetString("channel_organization"), - - ChannelCreateTime: c.GetInt64("channel_create_time"), - ParamOverride: paramOverride, - RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ 
IsFirstThinkingContent: true, SendLastThinkingContent: false, }, - - ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), - ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex), } + if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg") info.RequestURLPath = "/v1" + info.RequestURLPath } - if info.BaseUrl == "" { - info.BaseUrl = constant.ChannelBaseURLs[channelType] - } - if info.ChannelType == constant.ChannelTypeAzure { - info.ApiVersion = GetAPIVersion(c) - } - if info.ChannelType == constant.ChannelTypeVertexAi { - info.ApiVersion = c.GetString("region") - } - if streamSupportedChannels[info.ChannelType] { - info.SupportStreamOptions = true - } - - channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) - if ok { - info.ChannelSetting = channelSetting - } - - channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting) - if ok { - info.ChannelOtherSettings = channelOtherSettings - } userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting) if ok { @@ -307,12 +313,39 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { return info } -func (info *RelayInfo) SetPromptTokens(promptTokens int) { - info.PromptTokens = promptTokens +func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) { + switch relayFormat { + case types.RelayFormatOpenAI: + return GenRelayInfoOpenAI(c, request), nil + case types.RelayFormatOpenAIAudio: + return GenRelayInfoOpenAIAudio(c, request), nil + case types.RelayFormatOpenAIImage: + return GenRelayInfoImage(c, request), nil + case types.RelayFormatOpenAIRealtime: + return GenRelayInfoWs(c, ws), nil + case types.RelayFormatClaude: + return GenRelayInfoClaude(c, request), 
nil + case types.RelayFormatRerank: + if request, ok := request.(*dto.RerankRequest); ok { + return GenRelayInfoRerank(c, request), nil + } + return nil, errors.New("request is not a RerankRequest") + case types.RelayFormatGemini: + return GenRelayInfoGemini(c, request), nil + case types.RelayFormatEmbedding: + return GenRelayInfoEmbedding(c, request), nil + case types.RelayFormatOpenAIResponses: + if request, ok := request.(*dto.OpenAIResponsesRequest); ok { + return GenRelayInfoResponses(c, request), nil + } + return nil, errors.New("request is not a OpenAIResponsesRequest") + default: + return nil, errors.New("invalid relay format") + } } -func (info *RelayInfo) SetIsStream(isStream bool) { - info.IsStream = isStream +func (info *RelayInfo) SetPromptTokens(promptTokens int) { + info.PromptTokens = promptTokens } func (info *RelayInfo) SetFirstResponseTime() { diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index 57df5fe3..05dbfa6d 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -8,6 +8,7 @@ import ( "one-api/dto" "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -18,7 +19,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println("reranker response body: ", string(responseBody)) } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index fef8d2c9..f7906cf9 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -8,7 +8,6 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "one-api/types" @@ -16,69 
+15,27 @@ import ( "github.com/gin-gonic/gin" ) -func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { - token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) - return token -} +func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { -func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error { - if embeddingRequest.Input == nil { - return fmt.Errorf("input is empty") - } - if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { - embeddingRequest.Model = "omni-moderation-latest" - } - if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { - embeddingRequest.Model = c.Param("model") - } - return nil -} + info.InitChannelMeta(c) -func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoEmbedding(c) - - var embeddingRequest *dto.EmbeddingRequest - err := common.UnmarshalBodyReusable(c, &embeddingRequest) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) } - err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) + err := helper.ModelMappedHelper(c, info, embeddingRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptToken := getEmbeddingPromptToken(*embeddingRequest) - relayInfo.PromptTokens = promptToken - - 
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) - convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest) + convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -88,7 +45,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -104,12 +61,12 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError 
} - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index e0581156..3ebe0884 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -2,17 +2,16 @@ package relay import ( "bytes" - "errors" "fmt" "io" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/gemini" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -20,64 +19,6 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { - request := &dto.GeminiChatRequest{} - err := common.UnmarshalBodyReusable(c, request) - if err != nil { - return nil, err - } - if len(request.Contents) == 0 { - return nil, errors.New("contents is required") - } - return request, nil -} - -// 流模式 -// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx -func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) { - if c.Query("alt") == "sse" { - relayInfo.IsStream = true - } - - // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") { - // relayInfo.IsStream = true - // } -} - -func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) { - var inputTexts []string - for _, content := range textRequest.Contents { - for _, part := range content.Parts { - if part.Text != "" { - inputTexts = append(inputTexts, part.Text) - } - } - } - if len(inputTexts) == 0 { - return nil, nil - } - - sensitiveWords, err := service.CheckSensitiveInput(inputTexts) - return sensitiveWords, err -} - -func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int { - // 计算输入 token 数量 - var inputTexts []string - for _, content := range req.Contents { - for _, part := range 
content.Parts { - if part.Text != "" { - inputTexts = append(inputTexts, part.Text) - } - } - } - - inputText := strings.Join(inputTexts, "\n") - inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName) - info.PromptTokens = inputTokens - return inputTokens -} - func isNoThinkingRequest(req *dto.GeminiChatRequest) bool { if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget @@ -109,97 +50,61 @@ func trimModelThinking(modelName string) string { return modelName } -func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - req, err := getAndValidateGeminiRequest(c) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } +func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - relayInfo := relaycommon.GenRelayInfoGemini(c) - - // 检查 Gemini 流式模式 - checkGeminiStreamMode(c, relayInfo) - - if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkGeminiInputSensitive(req) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) - } + request, ok := info.Request.(*dto.GeminiChatRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request)) } // model mapped 模型映射 - err = helper.ModelMappedHelper(c, relayInfo, req) + err := helper.ModelMappedHelper(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens := value.(int) - 
relayInfo.SetPromptTokens(promptTokens) - } else { - promptTokens := getGeminiInputTokens(req, relayInfo) - c.Set("prompt_tokens", promptTokens) - } - if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - if isNoThinkingRequest(req) { + if isNoThinkingRequest(request) { // check is thinking - if !strings.Contains(relayInfo.OriginModelName, "-nothinking") { + if !strings.Contains(info.OriginModelName, "-nothinking") { // try to get no thinking model price - noThinkingModelName := relayInfo.OriginModelName + "-nothinking" + noThinkingModelName := info.OriginModelName + "-nothinking" containPrice := helper.ContainPriceOrRatio(noThinkingModelName) if containPrice { - relayInfo.OriginModelName = noThinkingModelName - relayInfo.UpstreamModelName = noThinkingModelName + info.OriginModelName = noThinkingModelName + info.UpstreamModelName = noThinkingModelName } } } - if req.GenerationConfig.ThinkingConfig == nil { - gemini.ThinkingAdaptor(req, relayInfo) + if request.GenerationConfig.ThinkingConfig == nil { + gemini.ThinkingAdaptor(request, info) } } - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre consume quota - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } 
- adaptor.Init(relayInfo) + adaptor.Init(info) // Clean up empty system instruction - if req.SystemInstructions != nil { + if request.SystemInstructions != nil { hasContent := false - for _, part := range req.SystemInstructions.Parts { + for _, part := range request.SystemInstructions.Parts { if part.Text != "" { hasContent = true break } } if !hasContent { - req.SystemInstructions = nil + request.SystemInstructions = nil } } var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) @@ -207,7 +112,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { requestBody = bytes.NewReader(body) } else { // 使用 ConvertGeminiRequest 转换请求格式 - convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req) + convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -217,10 +122,10 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -229,15 +134,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - if common.DebugEnabled { - println("Gemini request body: %s", string(jsonData)) - } + logger.LogDebug(c, "Gemini request body: "+string(jsonData)) + requestBody = 
bytes.NewReader(jsonData) } - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - common.LogError(c, "Do gemini request failed: "+err.Error()) + logger.LogError(c, "Do gemini request failed: "+err.Error()) return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -246,7 +150,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -255,23 +159,22 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) if openaiErr != nil { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } -func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoGemini(c) +func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents") - relayInfo.IsGeminiBatchEmbedding = isBatch + info.IsGeminiBatchEmbedding = isBatch - var promptTokens int var req any var err error var inputTexts []string @@ -303,35 +206,17 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError 
*types.NewAPIError) { } } } - promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName) - relayInfo.SetPromptTokens(promptTokens) - c.Set("prompt_tokens", promptTokens) - err = helper.ModelMappedHelper(c, relayInfo, req) + err = helper.ModelMappedHelper(c, info, req) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader jsonData, err := common.Marshal(req) @@ -340,10 +225,10 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -353,9 +238,9 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewReader(jsonData) - resp, err 
:= adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - common.LogError(c, "Do gemini request failed: "+err.Error()) + logger.LogError(c, "Do gemini request failed: "+err.Error()) return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -370,12 +255,12 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) if openaiErr != nil { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/helper/common.go b/relay/helper/common.go index c8edb798..5075314d 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/types" "github.com/gin-gonic/gin" @@ -100,7 +101,7 @@ func Done(c *gin.Context) { func WssString(c *gin.Context, ws *websocket.Conn, str string) error { if ws == nil { - common.LogError(c, "websocket connection is nil") + logger.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } //common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) @@ -113,7 +114,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { return fmt.Errorf("error marshalling object: %w", err) } if ws == nil { - common.LogError(c, "websocket connection is nil") + logger.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } //common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go index c1735149..e894e228 100644 --- 
a/relay/helper/model_mapped.go +++ b/relay/helper/model_mapped.go @@ -4,9 +4,10 @@ import ( "encoding/json" "errors" "fmt" - common2 "one-api/common" "one-api/dto" + common2 "one-api/logger" "one-api/relay/common" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -54,29 +55,29 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro } if request != nil { switch info.RelayFormat { - case common.RelayFormatGemini: + case types.RelayFormatGemini: // Gemini 模型映射 - case common.RelayFormatClaude: + case types.RelayFormatClaude: if claudeRequest, ok := request.(*dto.ClaudeRequest); ok { claudeRequest.Model = info.UpstreamModelName } - case common.RelayFormatOpenAIResponses: + case types.RelayFormatOpenAIResponses: if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok { openAIResponsesRequest.Model = info.UpstreamModelName } - case common.RelayFormatOpenAIAudio: + case types.RelayFormatOpenAIAudio: if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok { openAIAudioRequest.Model = info.UpstreamModelName } - case common.RelayFormatOpenAIImage: + case types.RelayFormatOpenAIImage: if imageRequest, ok := request.(*dto.ImageRequest); ok { imageRequest.Model = info.UpstreamModelName } - case common.RelayFormatRerank: + case types.RelayFormatRerank: if rerankRequest, ok := request.(*dto.RerankRequest); ok { rerankRequest.Model = info.UpstreamModelName } - case common.RelayFormatEmbedding: + case types.RelayFormatEmbedding: if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok { embeddingRequest.Model = info.UpstreamModelName } diff --git a/relay/helper/price.go b/relay/helper/price.go index e80578e5..89fc3b66 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -5,35 +5,14 @@ import ( "one-api/common" relaycommon "one-api/relay/common" "one-api/setting/ratio_setting" + "one-api/types" "github.com/gin-gonic/gin" ) -type GroupRatioInfo struct { - GroupRatio float64 - GroupSpecialRatio float64 - 
HasSpecialRatio bool -} - -type PriceData struct { - ModelPrice float64 - ModelRatio float64 - CompletionRatio float64 - CacheRatio float64 - CacheCreationRatio float64 - ImageRatio float64 - UsePrice bool - ShouldPreConsumedQuota int - GroupRatioInfo GroupRatioInfo -} - -func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) -} - // HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present -func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo { - groupRatioInfo := GroupRatioInfo{ +func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo { + groupRatioInfo := types.GroupRatioInfo{ GroupRatio: 1.0, // default ratio GroupSpecialRatio: -1, } @@ -62,7 +41,7 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR return groupRatioInfo } -func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { +func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) { modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false) groupRatioInfo := HandleGroupRatio(c, info) @@ -75,8 +54,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var cacheCreationRatio float64 if !usePrice { preConsumedTokens := common.PreConsumedQuota - if maxTokens != 0 { - preConsumedTokens = promptTokens + maxTokens + if meta.MaxTokens != 0 { + preConsumedTokens = promptTokens + meta.MaxTokens } var success bool var 
matchName string @@ -87,7 +66,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens acceptUnsetRatio = true } if !acceptUnsetRatio { - return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) + return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) } } completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName) @@ -97,10 +76,13 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens ratio := modelRatio * groupRatioInfo.GroupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { + if meta.ImagePriceRatio != 0 { + modelPrice = modelPrice * meta.ImagePriceRatio + } preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) } - priceData := PriceData{ + priceData := types.PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, CompletionRatio: completionRatio, @@ -115,38 +97,32 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens if common.DebugEnabled { println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting())) } - + info.PriceData = priceData return priceData, nil } -type PerCallPriceData struct { - ModelPrice float64 - Quota int - GroupRatioInfo GroupRatioInfo -} - // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData { - groupRatioInfo := HandleGroupRatio(c, info) - - modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) - // 如果没有配置价格,则使用默认价格 - if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] - if !ok { - modelPrice = 0.1 - } else { - modelPrice = defaultPrice - } - } - quota := int(modelPrice * common.QuotaPerUnit * 
groupRatioInfo.GroupRatio) - priceData := PerCallPriceData{ - ModelPrice: modelPrice, - Quota: quota, - GroupRatioInfo: groupRatioInfo, - } - return priceData -} +//func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { +// groupRatioInfo := HandleGroupRatio(c, info) +// +// modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) +// // 如果没有配置价格,则使用默认价格 +// if !success { +// defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] +// if !ok { +// modelPrice = 0.1 +// } else { +// modelPrice = defaultPrice +// } +// } +// quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) +// priceData := types.PerCallPriceData{ +// ModelPrice: modelPrice, +// Quota: quota, +// GroupRatioInfo: groupRatioInfo, +// } +// return priceData +//} func ContainPriceOrRatio(modelName string) bool { _, ok := ratio_setting.GetModelPrice(modelName, false) diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index a5706f95..725d178c 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/constant" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/setting/operation_setting" "strings" @@ -87,7 +88,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case <-done: case <-time.After(5 * time.Second): - common.LogError(c, "timeout waiting for goroutines to exit") + logger.LogError(c, "timeout waiting for goroutines to exit") } close(stopChan) @@ -109,7 +110,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer func() { wg.Done() if r := recover(); r != nil { - common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) + logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) common.SafeSendBool(stopChan, true) } if common.DebugEnabled { @@ -136,14 +137,14 @@ func 
StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case err := <-done: if err != nil { - common.LogError(c, "ping data error: "+err.Error()) + logger.LogError(c, "ping data error: "+err.Error()) return } if common.DebugEnabled { println("ping data sent") } case <-time.After(10 * time.Second): - common.LogError(c, "ping data send timeout") + logger.LogError(c, "ping data send timeout") return case <-ctx.Done(): return @@ -158,7 +159,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon // 监听客户端断开连接 return case <-pingTimeout.C: - common.LogError(c, "ping goroutine max duration reached") + logger.LogError(c, "ping goroutine max duration reached") return } } @@ -171,7 +172,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer func() { wg.Done() if r := recover(); r != nil { - common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) + logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) } common.SafeSendBool(stopChan, true) if common.DebugEnabled { @@ -223,7 +224,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon return } case <-time.After(10 * time.Second): - common.LogError(c, "data handler timeout") + logger.LogError(c, "data handler timeout") return case <-ctx.Done(): return @@ -241,7 +242,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon if err := scanner.Err(); err != nil { if err != io.EOF { - common.LogError(c, "scanner error: "+err.Error()) + logger.LogError(c, "scanner error: "+err.Error()) } } }) @@ -250,12 +251,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case <-ticker.C: // 超时处理逻辑 - common.LogError(c, "streaming timeout") + logger.LogError(c, "streaming timeout") case <-stopChan: // 正常结束 - common.LogInfo(c, "streaming finished") + logger.LogInfo(c, "streaming finished") case <-c.Request.Context().Done(): // 
客户端断开连接 - common.LogInfo(c, "client disconnected") + logger.LogInfo(c, "client disconnected") } } diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go new file mode 100644 index 00000000..0bc51774 --- /dev/null +++ b/relay/helper/valid_request.go @@ -0,0 +1,301 @@ +package helper + +import ( + "errors" + "fmt" + "math" + "one-api/common" + "one-api/dto" + "one-api/logger" + relayconstant "one-api/relay/constant" + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + +func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) { + relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) + + switch format { + case types.RelayFormatOpenAI: + request, err = GetAndValidateTextRequest(c, relayMode) + case types.RelayFormatGemini: + request, err = GetAndValidateGeminiRequest(c) + case types.RelayFormatClaude: + request, err = GetAndValidateClaudeRequest(c) + case types.RelayFormatOpenAIResponses: + request, err = GetAndValidateResponsesRequest(c) + + case types.RelayFormatOpenAIImage: + request, err = GetAndValidOpenAIImageRequest(c, relayMode) + case types.RelayFormatEmbedding: + request, err = GetAndValidateEmbeddingRequest(c, relayMode) + case types.RelayFormatRerank: + request, err = GetAndValidateRerankRequest(c) + case types.RelayFormatOpenAIAudio: + request, err = GetAndValidAudioRequest(c, relayMode) + case types.RelayFormatOpenAIRealtime: + // nothing to do, no request body + default: + return nil, fmt.Errorf("unsupported relay format: %s", format) + } + return request, err +} + +func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) { + audioRequest := &dto.AudioRequest{} + err := common.UnmarshalBodyReusable(c, audioRequest) + if err != nil { + return nil, err + } + switch relayMode { + case relayconstant.RelayModeAudioSpeech: + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + default: + err = c.Request.ParseForm() + 
if err != nil { + return nil, err + } + formData := c.Request.PostForm + if audioRequest.Model == "" { + audioRequest.Model = formData.Get("model") + } + + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + audioRequest.ResponseFormat = formData.Get("response_format") + if audioRequest.ResponseFormat == "" { + audioRequest.ResponseFormat = "json" + } + } + return audioRequest, nil +} + +func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) { + var rerankRequest *dto.RerankRequest + err := common.UnmarshalBodyReusable(c, &rerankRequest) + if err != nil { + logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + if rerankRequest.Query == "" { + return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + if len(rerankRequest.Documents) == 0 { + return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + return rerankRequest, nil +} + +func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) { + var embeddingRequest *dto.EmbeddingRequest + err := common.UnmarshalBodyReusable(c, &embeddingRequest) + if err != nil { + logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + if embeddingRequest.Input == nil { + return nil, fmt.Errorf("input is empty") + } + if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { + embeddingRequest.Model = "omni-moderation-latest" + } + if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { + embeddingRequest.Model = c.Param("model") + } + return embeddingRequest, nil +} + +func 
GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { + request := &dto.OpenAIResponsesRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if request.Model == "" { + return nil, errors.New("model is required") + } + if request.Input == nil { + return nil, errors.New("input is required") + } + return request, nil +} + +func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) { + imageRequest := &dto.ImageRequest{} + + switch relayMode { + case relayconstant.RelayModeImagesEdits: + _, err := c.MultipartForm() + if err != nil { + return nil, err + } + formData := c.Request.PostForm + imageRequest.Prompt = formData.Get("prompt") + imageRequest.Model = formData.Get("model") + imageRequest.N = uint(common.String2Int(formData.Get("n"))) + imageRequest.Quality = formData.Get("quality") + imageRequest.Size = formData.Get("size") + + if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + + watermark := formData.Has("watermark") + if watermark { + imageRequest.Watermark = &watermark + } + default: + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-3" + } + + if strings.Contains(imageRequest.Size, "×") { + return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") + } + + // Not "256x256", "512x512", or "1024x1024" + if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { + if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") + } + if imageRequest.Size == "" { + 
imageRequest.Size = "1024x1024" + } + } else if imageRequest.Model == "dall-e-3" { + if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { + return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") + } + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + } else if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "auto" + } + } + + if imageRequest.Prompt == "" { + return nil, errors.New("prompt is required") + } + + if imageRequest.N == 0 { + imageRequest.N = 1 + } + } + + return imageRequest, nil +} + +func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { + textRequest = &dto.ClaudeRequest{} + err = c.ShouldBindJSON(textRequest) + if err != nil { + return nil, err + } + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + if textRequest.Model == "" { + return nil, errors.New("field model is required") + } + + //if textRequest.Stream { + // relayInfo.IsStream = true + //} + + return textRequest, nil +} + +func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) { + textRequest := &dto.GeneralOpenAIRequest{} + err := common.UnmarshalBodyReusable(c, textRequest) + if err != nil { + return nil, err + } + + if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } + if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } + + if textRequest.MaxTokens > math.MaxInt32/2 { + return nil, errors.New("max_tokens is invalid") + } + if textRequest.Model == "" { + return nil, errors.New("model is required") + } + if 
textRequest.WebSearchOptions != nil { + if textRequest.WebSearchOptions.SearchContextSize != "" { + validSizes := map[string]bool{ + "high": true, + "medium": true, + "low": true, + } + if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { + return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") + } + } else { + textRequest.WebSearchOptions.SearchContextSize = "medium" + } + } + switch relayMode { + case relayconstant.RelayModeCompletions: + if textRequest.Prompt == "" { + return nil, errors.New("field prompt is required") + } + case relayconstant.RelayModeChatCompletions: + if len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + case relayconstant.RelayModeEmbeddings: + case relayconstant.RelayModeModerations: + if textRequest.Input == nil || textRequest.Input == "" { + return nil, errors.New("field input is required") + } + case relayconstant.RelayModeEdits: + if textRequest.Instruction == "" { + return nil, errors.New("field instruction is required") + } + } + return textRequest, nil +} + +func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { + + request := &dto.GeminiChatRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if len(request.Contents) == 0 { + return nil, errors.New("contents is required") + } + + //if c.Query("alt") == "sse" { + // relayInfo.IsStream = true + //} + + return request, nil +} diff --git a/relay/image_handler.go b/relay/image_handler.go index f0b69699..008a979d 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -3,19 +3,15 @@ package relay import ( "bytes" "encoding/json" - "errors" "fmt" "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" - "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" 
"strings" @@ -23,183 +19,41 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) { - imageRequest := &dto.ImageRequest{} +func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { - switch info.RelayMode { - case relayconstant.RelayModeImagesEdits: - _, err := c.MultipartForm() - if err != nil { - return nil, err - } - formData := c.Request.PostForm - imageRequest.Prompt = formData.Get("prompt") - imageRequest.Model = formData.Get("model") - imageRequest.N = common.String2Int(formData.Get("n")) - imageRequest.Quality = formData.Get("quality") - imageRequest.Size = formData.Get("size") + info.InitChannelMeta(c) - if imageRequest.Model == "gpt-image-1" { - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } + imageRequest, ok := info.Request.(*dto.ImageRequest) - if info.ApiType == constant.APITypeVolcEngine { - watermark := formData.Has("watermark") - imageRequest.Watermark = &watermark - } - default: - err := common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err - } - - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-3" - } - - if strings.Contains(imageRequest.Size, "×") { - return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") - } - - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - } else if imageRequest.Model == "dall-e-3" { - if imageRequest.Size != "" && 
imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") - } - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - } else if imageRequest.Model == "gpt-image-1" { - if imageRequest.Quality == "" { - imageRequest.Quality = "auto" - } - } - - if imageRequest.Prompt == "" { - return nil, errors.New("prompt is required") - } - - if imageRequest.N == 0 { - imageRequest.N = 1 - } + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request)) } - if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveInput(imageRequest.Prompt) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) - return nil, err - } - } - return imageRequest, nil -} - -func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoImage(c) - - imageRequest, err := getAndValidImageRequest(c, relayInfo) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - err = helper.ModelMappedHelper(c, relayInfo, imageRequest) + err := helper.ModelMappedHelper(c, info, imageRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - var preConsumedQuota int - var quota int - var userQuota int - if !priceData.UsePrice { - // modelRatio 16 = modelPrice $0.04 - // per 1 modelRatio = $0.04 / 16 - // 
priceData.ModelPrice = 0.0025 * priceData.ModelRatio - preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - } else { - sizeRatio := 1.0 - qualityRatio := 1.0 - - if strings.HasPrefix(imageRequest.Model, "dall-e") { - // Size - if imageRequest.Size == "256x256" { - sizeRatio = 0.4 - } else if imageRequest.Size == "512x512" { - sizeRatio = 0.45 - } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1 - } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - sizeRatio = 2 - } - - if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" { - qualityRatio = 2.0 - if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - qualityRatio = 1.5 - } - } - } - - // reset model price - priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) - quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) - userQuota, err = model.GetUserQuota(relayInfo.UserId, false) - if err != nil { - return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) - } - if userQuota-quota < 0 { - return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry()) - } - } - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + 
adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) + convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits { + if info.RelayMode == relayconstant.RelayModeImagesEdits { requestBody = convertedRequest.(io.Reader) } else { jsonData, err := json.Marshal(convertedRequest) @@ -208,10 +62,10 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -229,14 +83,14 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), 
"text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -245,7 +99,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) @@ -253,17 +107,23 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } if usage.(*dto.Usage).TotalTokens == 0 { - usage.(*dto.Usage).TotalTokens = imageRequest.N + usage.(*dto.Usage).TotalTokens = int(imageRequest.N) } if usage.(*dto.Usage).PromptTokens == 0 { - usage.(*dto.Usage).PromptTokens = imageRequest.N + usage.(*dto.Usage).PromptTokens = int(imageRequest.N) } + quality := "standard" if imageRequest.Quality == "hd" { quality = "hd" } - logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent) + var logContent string + + if len(imageRequest.Size) > 0 { + logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) + } + + postConsumeQuota(c, info, usage.(*dto.Usage), logContent) return nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index 50d574f3..de750e76 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -2,172 +2,56 @@ package relay import ( "bytes" - "errors" "fmt" "io" - "math" "net/http" "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/setting/operation_setting" "one-api/types" "strings" "time" - 
"github.com/bytedance/gopkg/util/gopool" "github.com/shopspring/decimal" "github.com/gin-gonic/gin" ) -func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) { - textRequest := &dto.GeneralOpenAIRequest{} - err := common.UnmarshalBodyReusable(c, textRequest) - if err != nil { - return nil, err - } - if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { - textRequest.Model = "text-moderation-latest" - } - if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { - textRequest.Model = c.Param("model") - } +func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { - if textRequest.MaxTokens > math.MaxInt32/2 { - return nil, errors.New("max_tokens is invalid") - } - if textRequest.Model == "" { - return nil, errors.New("model is required") - } - if textRequest.WebSearchOptions != nil { - if textRequest.WebSearchOptions.SearchContextSize != "" { - validSizes := map[string]bool{ - "high": true, - "medium": true, - "low": true, - } - if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { - return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") - } - } else { - textRequest.WebSearchOptions.SearchContextSize = "medium" - } - } - switch relayInfo.RelayMode { - case relayconstant.RelayModeCompletions: - if textRequest.Prompt == "" { - return nil, errors.New("field prompt is required") - } - case relayconstant.RelayModeChatCompletions: - if len(textRequest.Messages) == 0 { - return nil, errors.New("field messages is required") - } - case relayconstant.RelayModeEmbeddings: - case relayconstant.RelayModeModerations: - if textRequest.Input == nil || textRequest.Input == "" { - return nil, errors.New("field input is required") - } - case relayconstant.RelayModeEdits: - if textRequest.Instruction == "" { - return nil, errors.New("field instruction is required") - } - } - 
relayInfo.IsStream = textRequest.Stream - return textRequest, nil -} + info.InitChannelMeta(c) -func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { + textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest) - relayInfo := relaycommon.GenRelayInfo(c) - - // get & validate textRequest 获取并验证文本请求 - textRequest, err := getAndValidateTextRequest(c, relayInfo) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + if !ok { + //return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request) } if textRequest.WebSearchOptions != nil { c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize) } - if setting.ShouldCheckPromptSensitive() { - words, err := checkRequestSensitive(textRequest, relayInfo) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) - } - } - - err = helper.ModelMappedHelper(c, relayInfo, textRequest) + err := helper.ModelMappedHelper(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - // 获取 promptTokens,如果上下文中已经存在,则直接使用 - var promptTokens int - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens = value.(int) - relayInfo.PromptTokens = promptTokens - } else { - promptTokens, err = getPromptTokens(textRequest, relayInfo) - // count messages token error 计算promptTokens错误 - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) - } - c.Set("prompt_tokens", promptTokens) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, 
promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens)))) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newApiErr != nil { - return newApiErr - } - defer func() { - if newApiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - includeUsage := true - // 判断用户是否需要返回使用情况 - if textRequest.StreamOptions != nil { - includeUsage = textRequest.StreamOptions.IncludeUsage - } - - // 如果不支持StreamOptions,将StreamOptions设置为nil - if !relayInfo.SupportStreamOptions || !textRequest.Stream { - textRequest.StreamOptions = nil - } else { - // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions - if constant.ForceStreamOption { - textRequest.StreamOptions = &dto.StreamOptions{ - IncludeUsage: true, - } - } - } - - relayInfo.ShouldIncludeUsage = includeUsage - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) @@ -177,12 +61,12 @@ func TextHelper(c *gin.Context) (newAPIError 
*types.NewAPIError) { } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) + convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - if relayInfo.ChannelSetting.SystemPrompt != "" { + if info.ChannelSetting.SystemPrompt != "" { // 如果有系统提示,则将其添加到请求中 request := convertedRequest.(*dto.GeneralOpenAIRequest) containSystemPrompt := false @@ -196,22 +80,22 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { // 如果没有系统提示,则添加系统提示 systemMessage := dto.Message{ Role: request.GetSystemRoleName(), - Content: relayInfo.ChannelSetting.SystemPrompt, + Content: info.ChannelSetting.SystemPrompt, } request.Messages = append([]dto.Message{systemMessage}, request.Messages...) - } else if relayInfo.ChannelSetting.SystemPromptOverride { + } else if info.ChannelSetting.SystemPromptOverride { common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) // 如果有系统提示,且允许覆盖,则拼接到前面 for i, message := range request.Messages { if message.Role == request.GetSystemRoleName() { if message.IsStringContent() { - request.Messages[i].SetStringContent(relayInfo.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) + request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) } else { contents := message.ParseContent() contents = append([]dto.MediaContent{ { Type: dto.ContentTypeText, - Text: relayInfo.ChannelSetting.SystemPrompt, + Text: info.ChannelSetting.SystemPrompt, }, }, contents...) 
request.Messages[i].Content = contents @@ -228,10 +112,10 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -240,14 +124,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - if common.DebugEnabled { - println("requestBody: ", string(jsonData)) - } + logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData))) + requestBody = bytes.NewBuffer(jsonData) } var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -256,125 +139,31 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newApiErr = service.RelayErrorHandler(httpResp, false) + newApiErr := service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } } - usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newApiErr := adaptor.DoResponse(c, httpResp, info) if newApiErr != nil { // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } - if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { - 
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") } return nil } -func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - promptTokens, err = service.CountTokenChatRequest(info, *textRequest) - case relayconstant.RelayModeCompletions: - promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model) - case relayconstant.RelayModeModerations: - promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) - case relayconstant.RelayModeEmbeddings: - promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) - default: - err = errors.New("unknown relay mode") - promptTokens = 0 - } - info.PromptTokens = promptTokens - return promptTokens, err -} - -func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) { - var err error - var words []string - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - words, err = service.CheckSensitiveMessages(textRequest.Messages) - case relayconstant.RelayModeCompletions: - words, err = service.CheckSensitiveInput(textRequest.Prompt) - case relayconstant.RelayModeModerations: - words, err = service.CheckSensitiveInput(textRequest.Input) - case relayconstant.RelayModeEmbeddings: - words, err = service.CheckSensitiveInput(textRequest.Input) - } - return words, err -} - -// 预扣费并返回用户剩余配额 -func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) { - 
userQuota, err := model.GetUserQuota(relayInfo.UserId, false) - if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) - } - if userQuota <= 0 { - return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - if userQuota-preConsumedQuota < 0 { - return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - relayInfo.UserQuota = userQuota - if userQuota > 100*preConsumedQuota { - // 用户额度充足,判断令牌额度是否充足 - if !relayInfo.TokenUnlimited { - // 非无限令牌,判断令牌额度是否充足 - tokenQuota := c.GetInt("token_quota") - if tokenQuota > 100*preConsumedQuota { - // 令牌额度充足,信任令牌 - preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) - } - } else { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota))) - } - } - - if preConsumedQuota > 0 { - err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) - if err != nil { - return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) - if err != nil { - return 0, 0, types.NewError(err, 
types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) - } - } - return preConsumedQuota, userQuota, nil -} - -func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) { - if preConsumedQuota != 0 { - gopool.Go(func() { - relayInfoCopy := *relayInfo - - err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) - if err != nil { - common.SysError("error return pre-consumed quota: " + err.Error()) - } - }) - } -} - -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { if usage == nil { usage = &dto.Usage{ PromptTokens: relayInfo.PromptTokens, @@ -392,12 +181,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := priceData.CompletionRatio - cacheRatio := priceData.CacheRatio - imageRatio := priceData.ImageRatio - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice + completionRatio := relayInfo.PriceData.CompletionRatio + cacheRatio := relayInfo.PriceData.CacheRatio + imageRatio := relayInfo.PriceData.ImageRatio + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -470,7 +259,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, var audioInputQuota decimal.Decimal var audioInputPrice float64 - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { baseTokens := dPromptTokens // 减去 cached tokens var 
cachedTokensWithRatio decimal.Decimal @@ -518,7 +307,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, totalTokens := promptTokens + completionTokens var logContent string - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio) } else { logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) @@ -530,8 +319,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { if !ratio.IsZero() && quota == 0 { quota = 1 @@ -540,11 +329,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota if quotaDelta != 0 { - err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } @@ -560,7 +349,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, if extraContent != "" { logContent += ", " + extraContent } - 
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) if imageTokens != 0 { other["image"] = true other["image_ratio"] = imageRatio @@ -604,7 +393,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, diff --git a/relay/relay_task.go b/relay/relay_task.go index 0ccc3b33..ae002d73 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -127,7 +128,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + logger.SysError("error consuming token remain quota: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 1e547e2a..85e4f174 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -25,62 +25,33 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int { return token } -func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) { +func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { - var rerankRequest *dto.RerankRequest - err := common.UnmarshalBodyReusable(c, &rerankRequest) - if err != nil { - 
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + rerankRequest, ok := info.Request.(*dto.RerankRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request)) } - relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest) - - if rerankRequest.Query == "" { - return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - if len(rerankRequest.Documents) == 0 { - return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) + err := helper.ModelMappedHelper(c, info, rerankRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptToken := getRerankPromptToken(*rerankRequest) - relayInfo.PromptTokens = promptToken - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + 
adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) + convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -90,10 +61,10 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -108,7 +79,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError requestBody = bytes.NewBuffer(jsonData) } - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -125,12 +96,12 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - 
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 65c240b2..cd80da33 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -3,7 +3,6 @@ package relay import ( "bytes" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -12,7 +11,6 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -20,82 +18,24 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { - request := &dto.OpenAIResponsesRequest{} - err := common.UnmarshalBodyReusable(c, request) - if err != nil { - return nil, err - } - if request.Model == "" { - return nil, errors.New("model is required") - } - if len(request.Input) == 0 { - return nil, errors.New("input is required") - } - return request, nil +func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) -} - -func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { - sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) - return sensitiveWords, err -} - -func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int { - inputTokens := service.CountTokenInput(req.Input, req.Model) - info.PromptTokens = inputTokens - return inputTokens -} - -func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - req, err := getAndValidateResponsesRequest(c) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + request, ok := 
info.Request.(*dto.OpenAIResponsesRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request)) } - relayInfo := relaycommon.GenRelayInfoResponses(c, req) - - if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkInputSensitive(req, relayInfo) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) - } - } - - err = helper.ModelMappedHelper(c, relayInfo, req) + err := helper.ModelMappedHelper(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens := value.(int) - relayInfo.SetPromptTokens(promptTokens) - } else { - promptTokens := getInputTokens(req, relayInfo) - c.Set("prompt_tokens", promptTokens) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - // pre consume quota - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody 
io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled { body, err := common.GetRequestBody(c) @@ -104,7 +44,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req) + convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -113,13 +53,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) err = json.Unmarshal(jsonData, &reqMap) if err != nil { return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = json.Marshal(reqMap) @@ -135,7 +75,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -153,17 +93,17 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, 
relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") } return nil } diff --git a/relay/websocket.go b/relay/websocket.go index 3715b237..22b681f1 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -15,13 +15,6 @@ import ( func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) { relayInfo := relaycommon.GenRelayInfoWs(c, ws) - // get & validate textRequest 获取并验证文本请求 - //realtimeEvent, err := getAndValidateWssRequest(c, ws) - //if err != nil { - // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error())) - // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) - //} - err := helper.ModelMappedHelper(c, relayInfo, nil) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) diff --git a/router/main.go b/router/main.go index 0d2bfdce..7653f3a5 100644 --- a/router/main.go +++ b/router/main.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/logger" "os" "strings" ) @@ -18,7 +19,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" - common.SysLog("FRONTEND_BASE_URL is ignored on master node") + logger.SysLog("FRONTEND_BASE_URL is ignored on master node") } if frontendBaseUrl == "" { SetWebRouter(router, buildFS, indexPage) diff --git a/router/relay-router.go b/router/relay-router.go index cd656580..e0f05e97 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -1,11 +1,13 @@ package router import ( - 
"github.com/gin-gonic/gin" "one-api/constant" "one-api/controller" "one-api/middleware" "one-api/relay" + "one-api/types" + + "github.com/gin-gonic/gin" ) func SetRelayRouter(router *gin.Engine) { @@ -62,28 +64,83 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.Use(middleware.TokenAuth()) relayV1Router.Use(middleware.ModelRequestRateLimit()) { - // WebSocket 路由 + // WebSocket 路由(统一到 Relay) wsRouter := relayV1Router.Group("") wsRouter.Use(middleware.Distribute()) - wsRouter.GET("/realtime", controller.WssRelay) + wsRouter.GET("/realtime", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIRealtime) + }) } { //http router httpRouter := relayV1Router.Group("") httpRouter.Use(middleware.Distribute()) - httpRouter.POST("/messages", controller.RelayClaude) - httpRouter.POST("/completions", controller.Relay) - httpRouter.POST("/chat/completions", controller.Relay) - httpRouter.POST("/edits", controller.Relay) - httpRouter.POST("/images/generations", controller.Relay) - httpRouter.POST("/images/edits", controller.Relay) + + // claude related routes + httpRouter.POST("/messages", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatClaude) + }) + + // chat related routes + httpRouter.POST("/completions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + httpRouter.POST("/chat/completions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + + // response related routes + httpRouter.POST("/responses", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIResponses) + }) + + // image related routes + httpRouter.POST("/edits", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + httpRouter.POST("/images/generations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + httpRouter.POST("/images/edits", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + + // embedding related routes + 
httpRouter.POST("/embeddings", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatEmbedding) + }) + + // audio related routes + httpRouter.POST("/audio/transcriptions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + httpRouter.POST("/audio/translations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + httpRouter.POST("/audio/speech", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + + // rerank related routes + httpRouter.POST("/rerank", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatRerank) + }) + + // gemini relay routes + httpRouter.POST("/engines/:model/embeddings", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) + httpRouter.POST("/models/*path", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) + + // other relay routes + httpRouter.POST("/moderations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + + // not implemented httpRouter.POST("/images/variations", controller.RelayNotImplemented) - httpRouter.POST("/embeddings", controller.Relay) - httpRouter.POST("/engines/:model/embeddings", controller.Relay) - httpRouter.POST("/audio/transcriptions", controller.Relay) - httpRouter.POST("/audio/translations", controller.Relay) - httpRouter.POST("/audio/speech", controller.Relay) - httpRouter.POST("/responses", controller.Relay) httpRouter.GET("/files", controller.RelayNotImplemented) httpRouter.POST("/files", controller.RelayNotImplemented) httpRouter.DELETE("/files/:id", controller.RelayNotImplemented) @@ -95,9 +152,6 @@ func SetRelayRouter(router *gin.Engine) { httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) httpRouter.DELETE("/models/:model", controller.RelayNotImplemented) - httpRouter.POST("/moderations", controller.Relay) - httpRouter.POST("/rerank", 
controller.Relay) - httpRouter.POST("/models/*path", controller.Relay) } relayMjRouter := router.Group("/mj") @@ -121,7 +175,9 @@ func SetRelayRouter(router *gin.Engine) { relayGeminiRouter.Use(middleware.Distribute()) { // Gemini API 路径格式: /v1beta/models/{model_name}:{action} - relayGeminiRouter.POST("/models/*path", controller.Relay) + relayGeminiRouter.POST("/models/*path", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) } } diff --git a/service/cf_worker.go b/service/cf_worker.go index ae6e1ffe..65f7f133 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "net/http" - "one-api/common" + "one-api/logger" "one-api/setting" "strings" ) @@ -44,14 +44,14 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { if setting.EnableWorker() { - common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) + logger.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) req := &WorkerRequest{ URL: originUrl, Key: setting.WorkerValidKey, } return DoWorkerRequest(req) } else { - common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) + logger.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) return http.Get(originUrl) } } diff --git a/service/error.go b/service/error.go index 9672402d..668731b0 100644 --- a/service/error.go +++ b/service/error.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/types" "strconv" "strings" @@ -58,7 +59,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError lowerText := strings.ToLower(text) if !strings.HasPrefix(lowerText, "get file base64 from url") { if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - common.SysLog(fmt.Sprintf("error: %s", text)) + logger.SysLog(fmt.Sprintf("error: 
%s", text)) text = "请求上游地址失败" } } @@ -85,7 +86,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t if err != nil { return } - common.CloseResponseBodyGracefully(resp) + CloseResponseBodyGracefully(resp) var errResponse dto.GeneralErrorResponse err = common.Unmarshal(responseBody, &errResponse) @@ -138,7 +139,7 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { text := err.Error() lowerText := strings.ToLower(text) if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - common.SysLog(fmt.Sprintf("error: %s", text)) + logger.SysLog(fmt.Sprintf("error: %s", text)) text = "请求上游地址失败" } //避免暴露内部错误 diff --git a/common/http.go b/service/http.go similarity index 86% rename from common/http.go rename to service/http.go index d2e824ef..357a2e78 100644 --- a/common/http.go +++ b/service/http.go @@ -1,10 +1,12 @@ -package common +package service import ( "bytes" "fmt" "io" "net/http" + "one-api/common" + "one-api/logger" "github.com/gin-gonic/gin" ) @@ -15,7 +17,7 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) { } err := httpResponse.Body.Close() if err != nil { - SysError("failed to close response body: " + err.Error()) + common.SysError("failed to close response body: " + err.Error()) } } @@ -52,6 +54,6 @@ func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) { _, err := io.Copy(c.Writer, body) if err != nil { - LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) + logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) } } diff --git a/service/image.go b/service/image.go index 252093f1..957ca041 100644 --- a/service/image.go +++ b/service/image.go @@ -8,8 +8,8 @@ import ( "image" "io" "net/http" - "one-api/common" "one-api/constant" + "one-api/logger" "strings" "golang.org/x/image/webp" @@ -113,7 +113,7 @@ func GetImageFromUrl(url string) (mimeType string, data 
string, err error) { func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { response, err := DoDownloadRequest(imageUrl) if err != nil { - common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) + logger.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) return image.Config{}, "", err } defer response.Body.Close() @@ -131,7 +131,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { var readData []byte for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { - common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) + logger.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) // 从response.Body读取更多的数据直到达到当前的限制 additionalData := make([]byte, limit-int64(len(readData))) @@ -157,11 +157,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) { config, format, err := image.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) - common.SysLog(err.Error()) + logger.SysLog(err.Error()) config, err = webp.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) - common.SysLog(err.Error()) + logger.SysLog(err.Error()) } format = "webp" } diff --git a/service/midjourney.go b/service/midjourney.go index 1fc19682..1d232739 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relayconstant "one-api/relay/constant" "one-api/setting" "strconv" @@ -212,7 +213,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU defer cancel() resp, err := GetHttpClient().Do(req) if err != nil { - common.SysError("do request failed: " + err.Error()) + logger.SysError("do request failed: " + err.Error()) return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, 
"do_request_failed", http.StatusInternalServerError), nullBytes, err } statusCode := resp.StatusCode @@ -233,7 +234,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err } - common.CloseResponseBodyGracefully(resp) + CloseResponseBodyGracefully(resp) respStr := string(responseBody) log.Printf("respStr: %s", respStr) if respStr == "" { diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go new file mode 100644 index 00000000..3c4d0e7e --- /dev/null +++ b/service/pre_consume_quota.go @@ -0,0 +1,72 @@ +package service + +import ( + "errors" + "fmt" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "net/http" + "one-api/logger" + "one-api/model" + relaycommon "one-api/relay/common" + "one-api/types" +) + +func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { + if preConsumedQuota != 0 { + gopool.Go(func() { + relayInfoCopy := *relayInfo + + err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) + if err != nil { + logger.SysError("error return pre-consumed quota: " + err.Error()) + } + }) + } +} + +// PreConsumeQuota checks if the user has enough quota to pre-consume. +// It returns the pre-consumed quota if successful, or an error if not. 
+func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) { + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) + if err != nil { + return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) + } + if userQuota <= 0 { + return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + if userQuota-preConsumedQuota < 0 { + return 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + relayInfo.UserQuota = userQuota + if userQuota > 100*preConsumedQuota { + // 用户额度充足,判断令牌额度是否充足 + if !relayInfo.TokenUnlimited { + // 非无限令牌,判断令牌额度是否充足 + tokenQuota := c.GetInt("token_quota") + if tokenQuota > 100*preConsumedQuota { + // 令牌额度充足,信任令牌 + preConsumedQuota = 0 + logger.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) + } + } else { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + logger.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota))) + } + } + + if preConsumedQuota > 0 { + err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) + if err != nil { + return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + err = 
model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) + if err != nil { + return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + } + } + relayInfo.FinalPreConsumedQuota = preConsumedQuota + return preConsumedQuota, nil +} diff --git a/service/quota.go b/service/quota.go index 0f618402..d6f49d64 100644 --- a/service/quota.go +++ b/service/quota.go @@ -8,11 +8,12 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - "one-api/relay/helper" "one-api/setting" "one-api/setting/ratio_setting" + "one-api/types" "strings" "time" @@ -129,23 +130,23 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag quota := calculateAudioQuota(quotaInfo) if userQuota < quota { - return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)) + return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota)) } if !token.UnlimitedQuota && token.RemainQuota < quota { - return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota)) } err = PostConsumeQuota(relayInfo, quota, 0, false) if err != nil { return err } - common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) + logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) return nil } func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { + usage *dto.RealtimeUsage, 
extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.InputTokenDetails.TextTokens @@ -159,10 +160,10 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName)) - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - usePrice := priceData.UsePrice + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + usePrice := relayInfo.PriceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -196,8 +197,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) @@ -208,7 +209,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod logContent += ", " + extraContent } other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), 
audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: usage.InputTokens, @@ -218,7 +219,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -226,8 +226,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod }) } -func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -235,20 +234,20 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := priceData.CompletionRatio - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - cacheRatio := priceData.CacheRatio + completionRatio := relayInfo.PriceData.CompletionRatio + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + cacheRatio := relayInfo.PriceData.CacheRatio cacheTokens := usage.PromptTokensDetails.CachedTokens - cacheCreationRatio := priceData.CacheCreationRatio + cacheCreationRatio := 
relayInfo.PriceData.CacheCreationRatio cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens if relayInfo.ChannelType == constant.ChannelTypeOpenRouter { promptTokens -= cacheTokens - if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 { - maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData) + if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 { + maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData) if promptTokens >= maybeCacheCreationTokens { cacheCreationTokens = maybeCacheCreationTokens } @@ -257,7 +256,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } calculateQuota := 0.0 - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { calculateQuota = float64(promptTokens) calculateQuota += float64(cacheTokens) * cacheRatio calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio @@ -282,23 +281,23 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游出错)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota if quotaDelta != 0 { - err := 
PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, - cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, @@ -308,7 +307,6 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -317,7 +315,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } -func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int { +func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int { if priceData.CacheCreationRatio == 1 { return 0 } @@ -338,8 +336,7 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData (promptCacheCreatePrice - quotaPrice))) } -func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.PromptTokensDetails.TextTokens @@ 
-353,10 +350,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)) - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - usePrice := priceData.UsePrice + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + usePrice := relayInfo.PriceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -390,18 +387,18 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, 
"error consuming token remain quota: "+err.Error()) } } @@ -410,7 +407,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, logContent += ", " + extraContent } other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: usage.PromptTokens, @@ -420,7 +417,6 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -443,7 +439,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { return err } if !relayInfo.TokenUnlimited && token.RemainQuota < quota { - return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota)) } err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) if err != nil { @@ -501,7 +497,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon prompt := "您的额度即将用尽" topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" - err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink})) + err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink})) if err != nil { common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error())) } diff --git a/service/token_counter.go b/service/token_counter.go index eed5b5ca..ec817182 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -4,18 +4,22 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tiktoken-go/tokenizer" - "github.com/tiktoken-go/tokenizer/codec" "image" "log" "math" "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" + "one-api/types" "strings" "sync" "unicode/utf8" + + "github.com/gin-gonic/gin" + "github.com/tiktoken-go/tokenizer" + "github.com/tiktoken-go/tokenizer/codec" ) // tokenEncoderMap won't grow after initialization @@ -28,9 +32,9 @@ var tokenEncoderMap = make(map[string]tokenizer.Codec) var tokenEncoderMutex sync.RWMutex func InitTokenEncoders() { - common.SysLog("initializing token encoders") + logger.SysLog("initializing token encoders") defaultTokenEncoder = codec.NewCl100kBase() - common.SysLog("token encoders initialized") + logger.SysLog("token encoders initialized") } func getTokenEncoder(model string) tokenizer.Codec { @@ -72,52 +76,95 @@ func getTokenNum(tokenEncoder tokenizer.Codec, text string) int { return tkm } -func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { - if imageUrl == nil { +func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) { + if fileMeta == 
nil { return 0, fmt.Errorf("image_url_is_nil") } + + // Defaults for 4o/4.1/4.5 family unless overridden below baseTokens := 85 - if model == "glm-4v" { + tileTokens := 170 + + // Model classification + lowerModel := strings.ToLower(model) + + // Special cases from existing behavior + if strings.HasPrefix(lowerModel, "glm-4") { return 1047, nil } - if imageUrl.Detail == "low" { + + // Patch-based models (32x32 patches, capped at 1536, with multiplier) + isPatchBased := false + multiplier := 1.0 + switch { + case strings.Contains(lowerModel, "gpt-4.1-mini"): + isPatchBased = true + multiplier = 1.62 + case strings.Contains(lowerModel, "gpt-4.1-nano"): + isPatchBased = true + multiplier = 2.46 + case strings.HasPrefix(lowerModel, "o4-mini"): + isPatchBased = true + multiplier = 1.72 + case strings.HasPrefix(lowerModel, "gpt-5-mini"): + isPatchBased = true + multiplier = 1.62 + case strings.HasPrefix(lowerModel, "gpt-5-nano"): + isPatchBased = true + multiplier = 2.46 + } + + // Tile-based model tokens and bases per doc + if !isPatchBased { + if strings.HasPrefix(lowerModel, "gpt-4o-mini") { + baseTokens = 2833 + tileTokens = 5667 + } else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) { + baseTokens = 70 + tileTokens = 140 + } else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") { + baseTokens = 75 + tileTokens = 150 + } else if strings.Contains(lowerModel, "computer-use-preview") { + baseTokens = 65 + tileTokens = 129 + } else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") { + baseTokens = 85 + tileTokens = 170 + } + } + + // Respect existing feature flags/short-circuits + if fileMeta.Detail == "low" && !isPatchBased { return baseTokens, nil } if !constant.GetMediaTokenNotStream && !stream { 
return 3 * baseTokens, nil } - - // 同步One API的图片计费逻辑 - if imageUrl.Detail == "auto" || imageUrl.Detail == "" { - imageUrl.Detail = "high" + // Normalize detail + if fileMeta.Detail == "auto" || fileMeta.Detail == "" { + fileMeta.Detail = "high" } - - tileTokens := 170 - if strings.HasPrefix(model, "gpt-4o-mini") { - tileTokens = 5667 - baseTokens = 2833 - } - // 是否统计图片token + // Whether to count image tokens at all if !constant.GetMediaToken { return 3 * baseTokens, nil } - if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic { - return 3 * baseTokens, nil - } + + // Decode image to get dimensions var config image.Config var err error var format string var b64str string - if strings.HasPrefix(imageUrl.Url, "http") { - config, format, err = DecodeUrlImageData(imageUrl.Url) + if strings.HasPrefix(fileMeta.Data, "http") { + config, format, err = DecodeUrlImageData(fileMeta.Data) } else { - common.SysLog(fmt.Sprintf("decoding image")) - config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url) + logger.SysLog(fmt.Sprintf("decoding image")) + config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data) } if err != nil { return 0, err } - imageUrl.MimeType = format + fileMeta.MimeType = format if config.Width == 0 || config.Height == 0 { // not an image @@ -125,60 +172,144 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m // file type return 3 * baseTokens, nil } - return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url)) + return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data)) } - shortSide := config.Width - otherSide := config.Height - log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height) - // 缩放倍数 - scale := 1.0 - if config.Height < shortSide { - shortSide = config.Height - otherSide = config.Width + width := config.Width + height := 
config.Height + log.Printf("format: %s, width: %d, height: %d", format, width, height) + + if isPatchBased { + // 32x32 patch-based calculation with 1536 cap and model multiplier + ceilDiv := func(a, b int) int { return (a + b - 1) / b } + rawPatchesW := ceilDiv(width, 32) + rawPatchesH := ceilDiv(height, 32) + rawPatches := rawPatchesW * rawPatchesH + if rawPatches > 1536 { + // scale down + area := float64(width * height) + r := math.Sqrt(float64(32*32*1536) / area) + wScaled := float64(width) * r + hScaled := float64(height) * r + // adjust to fit whole number of patches after scaling + adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0) + adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0) + adj := math.Min(adjW, adjH) + if !math.IsNaN(adj) && adj > 0 { + r = r * adj + } + wScaled = float64(width) * r + hScaled = float64(height) * r + patchesW := math.Ceil(wScaled / 32.0) + patchesH := math.Ceil(hScaled / 32.0) + imageTokens := int(patchesW * patchesH) + if imageTokens > 1536 { + imageTokens = 1536 + } + return int(math.Round(float64(imageTokens) * multiplier)), nil + } + // below cap + imageTokens := rawPatches + return int(math.Round(float64(imageTokens) * multiplier)), nil } - // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768 - if shortSide > 768 { - scale = float64(shortSide) / 768 - shortSide = 768 + // Tile-based calculation for 4o/4.1/4.5/o1/o3/etc. 
+ // Step 1: fit within 2048x2048 square + maxSide := math.Max(float64(width), float64(height)) + fitScale := 1.0 + if maxSide > 2048 { + fitScale = maxSide / 2048.0 } - // 将另一边按照相同的比例缩小,向上取整 - otherSide = int(math.Ceil(float64(otherSide) / scale)) - log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale) - // 计算图片的token数量(边的长度除以512,向上取整) - tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512) - log.Printf("tiles: %d", tiles) + fitW := int(math.Round(float64(width) / fitScale)) + fitH := int(math.Round(float64(height) / fitScale)) + + // Step 2: scale so that shortest side is exactly 768 + minSide := math.Min(float64(fitW), float64(fitH)) + if minSide == 0 { + return baseTokens, nil + } + shortScale := 768.0 / minSide + finalW := int(math.Round(float64(fitW) * shortScale)) + finalH := int(math.Round(float64(fitH) * shortScale)) + + // Count 512px tiles + tilesW := (finalW + 512 - 1) / 512 + tilesH := (finalH + 512 - 1) / 512 + tiles := tilesW * tilesH + + if common.DebugEnabled { + log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles) + } + return tiles*tileTokens + baseTokens, nil } -func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) { - tkm := 0 - msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream) - if err != nil { - return 0, err +func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { + if meta == nil { + return 0, errors.New("token count meta is nil") } - tkm += msgTokens - if request.Tools != nil { - openaiTools := request.Tools - countStr := "" - for _, tool := range openaiTools { - countStr = tool.Function.Name - if tool.Function.Description != "" { - countStr += tool.Function.Description - } - if tool.Function.Parameters != nil { - countStr += fmt.Sprintf("%v", tool.Function.Parameters) - } - } - toolTokens := CountTokenInput(countStr, request.Model) - tkm += 8 - tkm 
+= toolTokens + model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) + tkm := CountTextToken(meta.CombineText, model) + + if info.RelayFormat == types.RelayFormatOpenAI { + tkm += meta.ToolsCount * 8 + tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量 + tkm += meta.NameCount * 3 + tkm += 3 } + for _, file := range meta.Files { + switch file.FileType { + case types.FileTypeImage: + if info.RelayFormat == types.RelayFormatGemini { + tkm += 240 + } else { + token, err := getImageToken(file, model, info.IsStream) + if err != nil { + return 0, fmt.Errorf("error counting image token: %v", err) + } + tkm += token + } + case types.FileTypeAudio: + tkm += 100 + case types.FileTypeVideo: + tkm += 5000 + case types.FileTypeFile: + tkm += 5000 + } + } + + common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm) return tkm, nil } +//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) { +// tkm := 0 +// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream) +// if err != nil { +// return 0, err +// } +// tkm += msgTokens +// if request.Tools != nil { +// openaiTools := request.Tools +// countStr := "" +// for _, tool := range openaiTools { +// countStr = tool.Function.Name +// if tool.Function.Description != "" { +// countStr += tool.Function.Description +// } +// if tool.Function.Parameters != nil { +// countStr += fmt.Sprintf("%v", tool.Function.Parameters) +// } +// } +// toolTokens := CountTokenInput(countStr, request.Model) +// tkm += 8 +// tkm += toolTokens +// } +// +// return tkm, nil +//} + func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) { tkm := 0 @@ -338,58 +469,55 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, return textToken, audioToken, nil } -func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) { - //recover when panic - 
tokenEncoder := getTokenEncoder(model) - // Reference: - // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - // https://github.com/pkoukk/tiktoken-go/issues/6 - // - // Every message follows <|start|>{role/name}\n{content}<|end|>\n - var tokensPerMessage int - var tokensPerName int - if model == "gpt-3.5-turbo-0301" { - tokensPerMessage = 4 - tokensPerName = -1 // If there's a name, the role is omitted - } else { - tokensPerMessage = 3 - tokensPerName = 1 - } - tokenNum := 0 - for _, message := range messages { - tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.Role) - if message.Content != nil { - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) - } - arrayContent := message.ParseContent() - for _, m := range arrayContent { - if m.Type == dto.ContentTypeImageURL { - imageUrl := m.GetImageMedia() - imageTokenNum, err := getImageToken(info, imageUrl, model, stream) - if err != nil { - return 0, err - } - tokenNum += imageTokenNum - log.Printf("image token num: %d", imageTokenNum) - } else if m.Type == dto.ContentTypeInputAudio { - // TODO: 音频token数量计算 - tokenNum += 100 - } else if m.Type == dto.ContentTypeFile { - tokenNum += 5000 - } else if m.Type == dto.ContentTypeVideoUrl { - tokenNum += 5000 - } else { - tokenNum += getTokenNum(tokenEncoder, m.Text) - } - } - } - } - tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum, nil -} +//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) { +// //recover when panic +// tokenEncoder := getTokenEncoder(model) +// // Reference: +// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb +// // https://github.com/pkoukk/tiktoken-go/issues/6 +// // +// // Every message follows <|start|>{role/name}\n{content}<|end|>\n +// var 
tokensPerMessage int +// var tokensPerName int +// +// tokensPerMessage = 3 +// tokensPerName = 1 +// +// tokenNum := 0 +// for _, message := range messages { +// tokenNum += tokensPerMessage +// tokenNum += getTokenNum(tokenEncoder, message.Role) +// if message.Content != nil { +// if message.Name != nil { +// tokenNum += tokensPerName +// tokenNum += getTokenNum(tokenEncoder, *message.Name) +// } +// arrayContent := message.ParseContent() +// for _, m := range arrayContent { +// if m.Type == dto.ContentTypeImageURL { +// imageUrl := m.GetImageMedia() +// imageTokenNum, err := getImageToken(info, imageUrl, model, stream) +// if err != nil { +// return 0, err +// } +// tokenNum += imageTokenNum +// log.Printf("image token num: %d", imageTokenNum) +// } else if m.Type == dto.ContentTypeInputAudio { +// // TODO: 音频token数量计算 +// tokenNum += 100 +// } else if m.Type == dto.ContentTypeFile { +// tokenNum += 5000 +// } else if m.Type == dto.ContentTypeVideoUrl { +// tokenNum += 5000 +// } else { +// tokenNum += getTokenNum(tokenEncoder, m.Text) +// } +// } +// } +// } +// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> +// return tokenNum, nil +//} func CountTokenInput(input any, model string) int { switch v := input.(type) { diff --git a/service/user_notify.go b/service/user_notify.go index 96664007..1fcc62d3 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -4,6 +4,7 @@ import ( "fmt" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "strings" ) @@ -12,7 +13,7 @@ func NotifyRootUser(t string, subject string, content string) { user := model.GetRootUser().ToBaseUser() err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil)) if err != nil { - common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error())) } } @@ -25,7 +26,7 @@ func NotifyUser(userId int, userEmail string, userSetting 
dto.UserSetting, data // Check notification limit canSend, err := CheckNotificationLimit(userId, data.Type) if err != nil { - common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) return err } if !canSend { @@ -37,14 +38,14 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // check setting email userEmail = userSetting.NotificationEmail if userEmail == "" { - common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) + logger.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) return nil } return sendEmailNotify(userEmail, data) case dto.NotifyTypeWebhook: webhookURLStr := userSetting.WebhookUrl if webhookURLStr == "" { - common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) + logger.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) return nil } diff --git a/setting/chat.go b/setting/chat.go index b97d65ce..b417af28 100644 --- a/setting/chat.go +++ b/setting/chat.go @@ -2,7 +2,7 @@ package setting import ( "encoding/json" - "one-api/common" + "one-api/logger" ) var Chats = []map[string]string{ @@ -37,7 +37,7 @@ func UpdateChatsByJsonString(jsonString string) error { func Chats2JsonString() string { jsonBytes, err := json.Marshal(Chats) if err != nil { - common.SysError("error marshalling chats: " + err.Error()) + logger.SysError("error marshalling chats: " + err.Error()) return "[]" } return string(jsonBytes) diff --git a/setting/config/config.go b/setting/config/config.go index 3af51b14..2e43e0a7 100644 --- a/setting/config/config.go +++ b/setting/config/config.go @@ -2,7 +2,7 @@ package config import ( "encoding/json" - "one-api/common" + "one-api/logger" "reflect" "strconv" "strings" @@ -57,7 +57,7 @@ func (cm *ConfigManager) LoadFromDB(options map[string]string) error { // 如果找到配置项,则更新配置 if len(configMap) > 0 { 
if err := updateConfigFromMap(config, configMap); err != nil { - common.SysError("failed to update config " + name + ": " + err.Error()) + logger.SysError("failed to update config " + name + ": " + err.Error()) continue } } diff --git a/setting/rate_limit.go b/setting/rate_limit.go index d550b2c3..dcb9fae5 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" "math" - "one-api/common" + "one-api/logger" "sync" ) @@ -21,7 +21,7 @@ func ModelRequestRateLimitGroup2JSONString() string { jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - common.SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index 3f223bc3..47079850 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -2,7 +2,7 @@ package ratio_setting import ( "encoding/json" - "one-api/common" + "one-api/logger" "sync" ) @@ -89,7 +89,7 @@ func CacheRatio2JSONString() string { defer cacheRatioMapMutex.RUnlock() jsonBytes, err := json.Marshal(cacheRatioMap) if err != nil { - common.SysError("error marshalling cache ratio: " + err.Error()) + logger.SysError("error marshalling cache ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/group_ratio.go b/setting/ratio_setting/group_ratio.go index 86f4a8d1..c1a666e9 100644 --- a/setting/ratio_setting/group_ratio.go +++ b/setting/ratio_setting/group_ratio.go @@ -3,7 +3,7 @@ package ratio_setting import ( "encoding/json" "errors" - "one-api/common" + "one-api/logger" "sync" ) @@ -48,7 +48,7 @@ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(groupRatio) if err != nil { - common.SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return 
string(jsonBytes) } @@ -67,7 +67,7 @@ func GetGroupRatio(name string) float64 { ratio, ok := groupRatio[name] if !ok { - common.SysError("group ratio not found: " + name) + logger.SysError("group ratio not found: " + name) return 1 } return ratio @@ -94,7 +94,7 @@ func GroupGroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupGroupRatio) if err != nil { - common.SysError("error marshalling group-group ratio: " + err.Error()) + logger.SysError("error marshalling group-group ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 4a19895e..ce822800 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -320,7 +320,7 @@ func ModelPrice2JSONString() string { modelPriceMapMutex.RLock() defer modelPriceMapMutex.RUnlock() - jsonBytes, err := json.Marshal(modelPriceMap) + jsonBytes, err := common.Marshal(modelPriceMap) if err != nil { common.SysError("error marshalling model price: " + err.Error()) } @@ -359,7 +359,7 @@ func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() modelRatioMap = make(map[string]float64) - err := json.Unmarshal([]byte(jsonStr), &modelRatioMap) + err := common.Unmarshal([]byte(jsonStr), &modelRatioMap) if err == nil { InvalidateExposedDataCache() } @@ -388,7 +388,7 @@ func GetModelRatio(name string) (float64, bool, string) { } func DefaultModelRatio2JSONString() string { - jsonBytes, err := json.Marshal(defaultModelRatio) + jsonBytes, err := common.Marshal(defaultModelRatio) if err != nil { common.SysError("error marshalling model ratio: " + err.Error()) } @@ -420,7 +420,7 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { CompletionRatioMutex.Lock() defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) - err := json.Unmarshal([]byte(jsonStr), &CompletionRatio) + err := 
common.Unmarshal([]byte(jsonStr), &CompletionRatio) if err == nil { InvalidateExposedDataCache() } @@ -594,7 +594,7 @@ func ModelRatio2JSONString() string { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() - jsonBytes, err := json.Marshal(modelRatioMap) + jsonBytes, err := common.Marshal(modelRatioMap) if err != nil { common.SysError("error marshalling model ratio: " + err.Error()) } @@ -610,7 +610,7 @@ var imageRatioMapMutex sync.RWMutex func ImageRatio2JSONString() string { imageRatioMapMutex.RLock() defer imageRatioMapMutex.RUnlock() - jsonBytes, err := json.Marshal(imageRatioMap) + jsonBytes, err := common.Marshal(imageRatioMap) if err != nil { common.SysError("error marshalling cache ratio: " + err.Error()) } @@ -621,7 +621,7 @@ func UpdateImageRatioByJSONString(jsonStr string) error { imageRatioMapMutex.Lock() defer imageRatioMapMutex.Unlock() imageRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &imageRatioMap) + return common.Unmarshal([]byte(jsonStr), &imageRatioMap) } func GetImageRatio(name string) (float64, bool) { diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go index 0ae132d0..bcbe712c 100644 --- a/setting/user_usable_group.go +++ b/setting/user_usable_group.go @@ -2,7 +2,7 @@ package setting import ( "encoding/json" - "one-api/common" + "one-api/logger" "sync" ) @@ -29,7 +29,7 @@ func UserUsableGroups2JSONString() string { jsonBytes, err := json.Marshal(userUsableGroups) if err != nil { - common.SysError("error marshalling user groups: " + err.Error()) + logger.SysError("error marshalling user groups: " + err.Error()) } return string(jsonBytes) } diff --git a/types/error.go b/types/error.go index 5a143612..2cfeb541 100644 --- a/types/error.go +++ b/types/error.go @@ -39,12 +39,13 @@ const ( ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected" // new api error - ErrorCodeCountTokenFailed ErrorCode = "count_token_failed" - ErrorCodeModelPriceError ErrorCode = 
"model_price_error" - ErrorCodeInvalidApiType ErrorCode = "invalid_api_type" - ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed" - ErrorCodeDoRequestFailed ErrorCode = "do_request_failed" - ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed" + ErrorCodeCountTokenFailed ErrorCode = "count_token_failed" + ErrorCodeModelPriceError ErrorCode = "model_price_error" + ErrorCodeInvalidApiType ErrorCode = "invalid_api_type" + ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed" + ErrorCodeDoRequestFailed ErrorCode = "do_request_failed" + ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed" + ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed" // channel error ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" diff --git a/types/price_data.go b/types/price_data.go new file mode 100644 index 00000000..f6a92d7e --- /dev/null +++ b/types/price_data.go @@ -0,0 +1,31 @@ +package types + +import "fmt" + +type GroupRatioInfo struct { + GroupRatio float64 + GroupSpecialRatio float64 + HasSpecialRatio bool +} + +type PriceData struct { + ModelPrice float64 + ModelRatio float64 + CompletionRatio float64 + CacheRatio float64 + CacheCreationRatio float64 + ImageRatio float64 + UsePrice bool + ShouldPreConsumedQuota int + GroupRatioInfo GroupRatioInfo +} + +type PerCallPriceData struct { + ModelPrice float64 + Quota int + GroupRatioInfo GroupRatioInfo +} + +func (p PriceData) ToSetting() string { + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) +} diff --git a/types/relay_format.go b/types/relay_format.go new file mode 100644 index 00000000..4c29d649 --- /dev/null +++ b/types/relay_format.go @@ -0,0 +1,15 @@ +package types + 
+type RelayFormat string + +const ( + RelayFormatOpenAI RelayFormat = "openai" + RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" + RelayFormatOpenAIResponses = "openai_responses" + RelayFormatOpenAIAudio = "openai_audio" + RelayFormatOpenAIImage = "openai_image" + RelayFormatOpenAIRealtime = "openai_realtime" + RelayFormatRerank = "rerank" + RelayFormatEmbedding = "embedding" +) diff --git a/types/relay_request.go b/types/relay_request.go new file mode 100644 index 00000000..b9d092f0 --- /dev/null +++ b/types/relay_request.go @@ -0,0 +1,27 @@ +package types + +type RelayRequest struct { + OriginRequest any + Format RelayFormat + PromptTokenCount int +} + +func (r *RelayRequest) CopyOriginRequest() any { + if r.OriginRequest == nil { + return nil + } + switch v := r.OriginRequest.(type) { + case *GeneralOpenAIRequest: + return v.Copy() + case *GeneralClaudeRequest: + return v.Copy() + case *GeneralGeminiRequest: + return v.Copy() + case *GeneralRerankRequest: + return v.Copy() + case *GeneralEmbeddingRequest: + return v.Copy() + default: + return nil + } +} diff --git a/types/request_meta.go b/types/request_meta.go new file mode 100644 index 00000000..427bacb9 --- /dev/null +++ b/types/request_meta.go @@ -0,0 +1,45 @@ +package types + +type FileType string + +const ( + FileTypeImage FileType = "image" // Image file type + FileTypeAudio FileType = "audio" // Audio file type + FileTypeVideo FileType = "video" // Video file type + FileTypeFile FileType = "file" // Generic file type +) + +type TokenType string + +const ( + TokenTypeTextNumber TokenType = "text_number" // Text or number tokens + TokenTypeTokenizer TokenType = "tokenizer" // Tokenizer tokens + TokenTypeImage TokenType = "image" // Image tokens +) + +type TokenCountMeta struct { + TokenType TokenType `json:"token_type,omitempty"` // Type of tokens used in the request + CombineText string `json:"combine_text,omitempty"` // Combined text from all messages + ToolsCount int 
`json:"tools_count,omitempty"` // Number of tools used + NameCount int `json:"name_count,omitempty"` // Number of names in the request + MessagesCount int `json:"messages_count,omitempty"` // Number of messages in the request + Files []*FileMeta `json:"files,omitempty"` // List of files, each with type and content + MaxTokens int `json:"max_tokens,omitempty"` // Maximum tokens allowed in the request + + ImagePriceRatio float64 `json:"image_ratio,omitempty"` // Ratio for image size, if applicable + //IsStreaming bool `json:"is_streaming,omitempty"` // Indicates if the request is streaming +} + +type FileMeta struct { + FileType + MimeType string + Data string + Detail string +} + +type RequestMeta struct { + OriginalModelName string `json:"original_model_name"` + UserUsingGroup string `json:"user_using_group"` + PromptTokens int `json:"prompt_tokens"` + PreConsumedQuota int `json:"pre_consumed_quota"` +}