diff --git a/common/quota.go b/common/quota.go new file mode 100644 index 00000000..dfd65d27 --- /dev/null +++ b/common/quota.go @@ -0,0 +1,5 @@ +package common + +func GetTrustQuota() int { + return int(10 * QuotaPerUnit) +} diff --git a/common/str.go b/common/str.go index f5399eab..6debce28 100644 --- a/common/str.go +++ b/common/str.go @@ -99,12 +99,75 @@ func GetJsonString(data any) string { return string(b) } -// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string +// MaskEmail masks a user email to prevent PII leakage in logs +// Returns "***masked***" if email is empty, otherwise shows only the domain part +func MaskEmail(email string) string { + if email == "" { + return "***masked***" + } + + // Find the @ symbol + atIndex := strings.Index(email, "@") + if atIndex == -1 { + // No @ symbol found, return masked + return "***masked***" + } + + // Return only the domain part with @ symbol + return "***@" + email[atIndex+1:] +} + +// maskHostTail returns the tail parts of a domain/host that should be preserved. +// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD. +func maskHostTail(parts []string) []string { + if len(parts) < 2 { + return parts + } + lastPart := parts[len(parts)-1] + secondLastPart := parts[len(parts)-2] + if len(lastPart) == 2 && len(secondLastPart) <= 3 { + // Likely country code TLD like co.uk, com.cn + return []string{secondLastPart, lastPart} + } + return []string{lastPart} +} + +// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail. +// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk +func maskHostForURL(host string) string { + parts := strings.Split(host, ".") + if len(parts) < 2 { + return "***" + } + tail := maskHostTail(parts) + return "***." + strings.Join(tail, ".") +} + +// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***. +// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk +func maskHostForPlainDomain(domain string) string { + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return domain + } + tail := maskHostTail(parts) + numStars := len(parts) - len(tail) + if numStars < 1 { + numStars = 1 + } + stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".") + return stars + "." + strings.Join(tail, ".") +} + +// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string // Example: // http://example.com -> http://***.com // https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=*** // https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/*** // 192.168.1.1 -> ***.***.***.*** +// openai.com -> ***.com +// www.openai.com -> ***.***.com +// api.openai.com -> ***.***.com func MaskSensitiveInfo(str string) string { // Mask URLs urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`) @@ -119,32 +182,8 @@ func MaskSensitiveInfo(str string) string { return urlStr } - // Split host by dots - parts := strings.Split(host, ".") - if len(parts) < 2 { - // If less than 2 parts, just mask the whole host - return u.Scheme + "://***" + u.Path - } - - // Keep the TLD (Top Level Domain) and mask the rest - var maskedHost string - if len(parts) == 2 { - // example.com -> ***.com - maskedHost = "***." + parts[len(parts)-1] - } else { - // Handle cases like sub.domain.co.uk or api.example.com - // Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.) - lastPart := parts[len(parts)-1] - secondLastPart := parts[len(parts)-2] - - if len(lastPart) == 2 && len(secondLastPart) <= 3 { - // Likely country code TLD like co.uk, com.cn - maskedHost = "***." + secondLastPart + "." + lastPart - } else { - // Regular TLD like .com, .org - maskedHost = "***." + lastPart - } - } + // Mask host with unified logic + maskedHost := maskHostForURL(host) result := u.Scheme + "://" + maskedHost @@ -184,6 +223,12 @@ func MaskSensitiveInfo(str string) string { return result }) + // Mask domain names without protocol (like openai.com, www.openai.com) + domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`) + str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string { + return maskHostForPlainDomain(domain) + }) + // Mask IP addresses ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`) str = ipPattern.ReplaceAllString(str, "***.***.***.***") diff --git a/common/sys_log.go b/common/sys_log.go new file mode 100644 index 00000000..478015f0 --- /dev/null +++ b/common/sys_log.go @@ -0,0 +1,24 @@ +package common + +import ( + "fmt" + "github.com/gin-gonic/gin" + "os" + "time" +) + +func SysLog(s string) { + t := time.Now() + _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) +} + +func SysError(s string) { + t := time.Now() + _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) +} + +func FatalLog(v ...any) { + t := time.Now() + _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) + os.Exit(1) +} diff --git a/constant/context_key.go b/constant/context_key.go index b82b19e7..569a0373 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -3,6 +3,8 @@ package constant type ContextKey string const ( + ContextKeyPromptTokens ContextKey = "prompt_tokens" + ContextKeyOriginalModel ContextKey = "original_model" ContextKeyRequestStartTime ContextKey = "request_start_time" diff --git a/controller/channel-test.go b/controller/channel-test.go index 026a863b..ea37c5bf 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -132,10 +132,27 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: newAPIError, } } + request := buildTestRequest(testModel) - info := relaycommon.GenRelayInfo(c) + // Determine relay format based on request path + relayFormat := types.RelayFormatOpenAI + if c.Request.URL.Path == "/v1/embeddings" { + relayFormat = types.RelayFormatEmbedding + } - err = helper.ModelMappedHelper(c, info, nil) + info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) + + if err != nil { + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed), + } + } + + info.InitChannelMeta(c) + + err = helper.ModelMappedHelper(c, info, request) if err != nil { return testResult{ context: c, @@ -143,7 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError), } } + testModel = info.UpstreamModelName + request.Model = testModel apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) @@ -155,13 +174,12 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } - request := buildTestRequest(testModel) - // 创建一个用于日志的 info 副本,移除 ApiKey - logInfo := *info - logInfo.ApiKey = "" - common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) + //// 创建一个用于日志的 info 副本,移除 ApiKey + //logInfo := info + //logInfo.ApiKey = "" + common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString())) - priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens())) + priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta()) if err != nil { return testResult{ context: c, diff --git a/controller/console_migrate.go b/controller/console_migrate.go index d25f199b..f0812c3d 100644 --- a/controller/console_migrate.go +++ b/controller/console_migrate.go @@ -3,101 +3,102 @@ package controller import ( - "encoding/json" - "net/http" - "one-api/common" - "one-api/model" - "github.com/gin-gonic/gin" + "encoding/json" + "net/http" + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" ) // MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.* func MigrateConsoleSetting(c *gin.Context) { - // 读取全部 option - opts, err := model.AllOption() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) - return - } - // 建立 map - valMap := map[string]string{} - for _, o := range opts { - valMap[o.Key] = o.Value - } + // 读取全部 option + opts, err := model.AllOption() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) + return + } + // 建立 map + valMap := map[string]string{} + for _, o := range opts { + valMap[o.Key] = o.Value + } - // 处理 APIInfo - if v := valMap["ApiInfo"]; v != "" { - var arr []map[string]interface{} - if err := json.Unmarshal([]byte(v), &arr); err == nil { - if len(arr) > 50 { - arr = arr[:50] - } - bytes, _ := json.Marshal(arr) - model.UpdateOption("console_setting.api_info", string(bytes)) - } - model.UpdateOption("ApiInfo", "") - } - // Announcements 直接搬 - if v := valMap["Announcements"]; v != "" { - model.UpdateOption("console_setting.announcements", v) - model.UpdateOption("Announcements", "") - } - // FAQ 转换 - if v := valMap["FAQ"]; v != "" { - var arr []map[string]interface{} - if err := json.Unmarshal([]byte(v), &arr); err == nil { - out := []map[string]interface{}{} - for _, item := range arr { - q, _ := item["question"].(string) - if q == "" { - q, _ = item["title"].(string) - } - a, _ := item["answer"].(string) - if a == "" { - a, _ = item["content"].(string) - } - if q != "" && a != "" { - out = append(out, map[string]interface{}{"question": q, "answer": a}) - } - } - if len(out) > 50 { - out = out[:50] - } - bytes, _ := json.Marshal(out) - model.UpdateOption("console_setting.faq", string(bytes)) - } - model.UpdateOption("FAQ", "") - } - // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups) - url := valMap["UptimeKumaUrl"] - slug := valMap["UptimeKumaSlug"] - if url != "" && slug != "" { - // 仅当同时存在 URL 与 Slug 时才进行迁移 - groups := []map[string]interface{}{ - { - "id": 1, - "categoryName": "old", - "url": url, - "slug": slug, - "description": "", - }, - } - bytes, _ := json.Marshal(groups) - model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes)) - } - // 清空旧键内容 - if url != "" { - model.UpdateOption("UptimeKumaUrl", "") - } - if slug != "" { - model.UpdateOption("UptimeKumaSlug", "") - } + // 处理 APIInfo + if v := valMap["ApiInfo"]; v != "" { + var arr []map[string]interface{} + if err := json.Unmarshal([]byte(v), &arr); err == nil { + if len(arr) > 50 { + arr = arr[:50] + } + bytes, _ := json.Marshal(arr) + model.UpdateOption("console_setting.api_info", string(bytes)) + } + model.UpdateOption("ApiInfo", "") + } + // Announcements 直接搬 + if v := valMap["Announcements"]; v != "" { + model.UpdateOption("console_setting.announcements", v) + model.UpdateOption("Announcements", "") + } + // FAQ 转换 + if v := valMap["FAQ"]; v != "" { + var arr []map[string]interface{} + if err := json.Unmarshal([]byte(v), &arr); err == nil { + out := []map[string]interface{}{} + for _, item := range arr { + q, _ := item["question"].(string) + if q == "" { + q, _ = item["title"].(string) + } + a, _ := item["answer"].(string) + if a == "" { + a, _ = item["content"].(string) + } + if q != "" && a != "" { + out = append(out, map[string]interface{}{"question": q, "answer": a}) + } + } + if len(out) > 50 { + out = out[:50] + } + bytes, _ := json.Marshal(out) + model.UpdateOption("console_setting.faq", string(bytes)) + } + model.UpdateOption("FAQ", "") + } + // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups) + url := valMap["UptimeKumaUrl"] + slug := valMap["UptimeKumaSlug"] + if url != "" && slug != "" { + // 仅当同时存在 URL 与 Slug 时才进行迁移 + groups := []map[string]interface{}{ + { + "id": 1, + "categoryName": "old", + "url": url, + "slug": slug, + "description": "", + }, + } + bytes, _ := json.Marshal(groups) + model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes)) + } + // 清空旧键内容 + if url != "" { + model.UpdateOption("UptimeKumaUrl", "") + } + if slug != "" { + model.UpdateOption("UptimeKumaSlug", "") + } - // 删除旧键记录 - oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"} - model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{}) + // 删除旧键记录 + oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"} + model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{}) - // 重新加载 OptionMap - model.InitOptionMap() - common.SysLog("console setting migrated") - c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) -} \ No newline at end of file + // 重新加载 OptionMap + model.InitOptionMap() + common.SysLog("console setting migrated") + c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) +} diff --git a/controller/midjourney.go b/controller/midjourney.go index 30a5a09a..a67d39c2 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -9,6 +9,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -28,7 +29,7 @@ func UpdateMidjourneyTaskBulk() { continue } - common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) + logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) taskChannelM := make(map[int][]string) taskM := make(map[string]*model.Midjourney) nullTaskIds := make([]int, 0) @@ -47,9 +48,9 @@ func UpdateMidjourneyTaskBulk() { "progress": "100%", }) if err != nil { - common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) } else { - common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) + logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { @@ -57,20 +58,20 @@ func UpdateMidjourneyTaskBulk() { } for channelId, taskIds := range taskChannelM { - common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { continue } midjourneyChannel, err := model.CacheGetChannel(channelId) if err != nil { - common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) + logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) err := model.MjBulkUpdate(taskIds, map[string]any{ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), "status": "FAILURE", "progress": "100%", }) if err != nil { - common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) + logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) } continue } @@ -81,7 +82,7 @@ func UpdateMidjourneyTaskBulk() { }) req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) continue } // 设置超时时间 @@ -93,22 +94,22 @@ func UpdateMidjourneyTaskBulk() { req.Header.Set("mj-api-secret", midjourneyChannel.Key) resp, err := service.GetHttpClient().Do(req) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue } if resp.StatusCode != http.StatusOK { - common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) continue } responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) continue } var responseItems []dto.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) continue } resp.Body.Close() @@ -147,12 +148,12 @@ func UpdateMidjourneyTaskBulk() { } // 映射 VideoUrl task.VideoUrl = responseItem.VideoUrl - + // 映射 VideoUrls - 将数组序列化为 JSON 字符串 if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 { videoUrlsStr, err := json.Marshal(responseItem.VideoUrls) if err != nil { - common.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err)) + logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err)) task.VideoUrls = "[]" // 失败时设置为空数组 } else { task.VideoUrls = string(videoUrlsStr) @@ -160,10 +161,10 @@ func UpdateMidjourneyTaskBulk() { } else { task.VideoUrls = "" // 空值时清空字段 } - + shouldReturnQuota := false if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { - common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) + logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" if task.Quota != 0 { shouldReturnQuota = true @@ -171,14 +172,14 @@ func UpdateMidjourneyTaskBulk() { } err = task.Update() if err != nil { - common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) + logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) } else { if shouldReturnQuota { err = model.IncreaseUserQuota(task.UserId, task.Quota, false) if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota)) + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } diff --git a/controller/model.go b/controller/model.go index d03fdeb2..398503e8 100644 --- a/controller/model.go +++ b/controller/model.go @@ -93,7 +93,9 @@ func init() { if !success || apiType == constant.APITypeAIProxyLibrary { continue } - meta := &relaycommon.RelayInfo{ChannelType: i} + meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{ + ChannelType: i, + }} adaptor := relay.GetAdaptor(apiType) adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() diff --git a/controller/oidc.go b/controller/oidc.go index df8ea1c4..f3def0e3 100644 --- a/controller/oidc.go +++ b/controller/oidc.go @@ -69,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } if oidcResponse.AccessToken == "" { - common.SysError("OIDC 获取 Token 失败,请检查设置!") + common.SysLog("OIDC 获取 Token 失败,请检查设置!") return nil, errors.New("OIDC 获取 Token 失败,请检查设置!") } @@ -85,7 +85,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } defer res2.Body.Close() if res2.StatusCode != http.StatusOK { - common.SysError("OIDC 获取用户信息失败!请检查设置!") + common.SysLog("OIDC 获取用户信息失败!请检查设置!") return nil, errors.New("OIDC 获取用户信息失败!请检查设置!") } @@ -95,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { return nil, err } if oidcUser.OpenID == "" || oidcUser.Email == "" { - common.SysError("OIDC 获取用户信息为空!请检查设置!") + common.SysLog("OIDC 获取用户信息为空!请检查设置!") return nil, errors.New("OIDC 获取用户信息为空!请检查设置!") } return &oidcUser, nil diff --git a/controller/playground.go b/controller/playground.go index dd930802..8a1cb2b6 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -56,5 +56,5 @@ func Playground(c *gin.Context) { //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) - Relay(c) + Relay(c, types.RelayFormatOpenAI) } diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index 0453870d..6fba0aac 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -1,474 +1,474 @@ package controller import ( - "context" - "encoding/json" - "fmt" - "net/http" - "strings" - "sync" - "time" + "context" + "encoding/json" + "fmt" + "net/http" + "one-api/logger" + "strings" + "sync" + "time" - "one-api/common" - "one-api/dto" - "one-api/model" - "one-api/setting/ratio_setting" + "one-api/dto" + "one-api/model" + "one-api/setting/ratio_setting" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) const ( - defaultTimeoutSeconds = 10 - defaultEndpoint = "/api/ratio_config" - maxConcurrentFetches = 8 + defaultTimeoutSeconds = 10 + defaultEndpoint = "/api/ratio_config" + maxConcurrentFetches = 8 ) var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} type upstreamResult struct { - Name string `json:"name"` - Data map[string]any `json:"data,omitempty"` - Err string `json:"err,omitempty"` + Name string `json:"name"` + Data map[string]any `json:"data,omitempty"` + Err string `json:"err,omitempty"` } func FetchUpstreamRatios(c *gin.Context) { - var req dto.UpstreamRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) - return - } + var req dto.UpstreamRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } - if req.Timeout <= 0 { - req.Timeout = defaultTimeoutSeconds - } + if req.Timeout <= 0 { + req.Timeout = defaultTimeoutSeconds + } - var upstreams []dto.UpstreamDTO + var upstreams []dto.UpstreamDTO - if len(req.Upstreams) > 0 { - for _, u := range req.Upstreams { - if strings.HasPrefix(u.BaseURL, "http") { - if u.Endpoint == "" { - u.Endpoint = defaultEndpoint - } - u.BaseURL = strings.TrimRight(u.BaseURL, "/") - upstreams = append(upstreams, u) - } - } - } else if len(req.ChannelIDs) > 0 { - intIds := make([]int, 0, len(req.ChannelIDs)) - for _, id64 := range req.ChannelIDs { - intIds = append(intIds, int(id64)) - } - dbChannels, err := model.GetChannelsByIds(intIds) - if err != nil { - common.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) - c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) - return - } - for _, ch := range dbChannels { - if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { - upstreams = append(upstreams, dto.UpstreamDTO{ - ID: ch.Id, - Name: ch.Name, - BaseURL: strings.TrimRight(base, "/"), - Endpoint: "", - }) - } - } - } + if len(req.Upstreams) > 0 { + for _, u := range req.Upstreams { + if strings.HasPrefix(u.BaseURL, "http") { + if u.Endpoint == "" { + u.Endpoint = defaultEndpoint + } + u.BaseURL = strings.TrimRight(u.BaseURL, "/") + upstreams = append(upstreams, u) + } + } + } else if len(req.ChannelIDs) > 0 { + intIds := make([]int, 0, len(req.ChannelIDs)) + for _, id64 := range req.ChannelIDs { + intIds = append(intIds, int(id64)) + } + dbChannels, err := model.GetChannelsByIds(intIds) + if err != nil { + logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) + return + } + for _, ch := range dbChannels { + if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { + upstreams = append(upstreams, dto.UpstreamDTO{ + ID: ch.Id, + Name: ch.Name, + BaseURL: strings.TrimRight(base, "/"), + Endpoint: "", + }) + } + } + } - if len(upstreams) == 0 { - c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) - return - } + if len(upstreams) == 0 { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) + return + } - var wg sync.WaitGroup - ch := make(chan upstreamResult, len(upstreams)) + var wg sync.WaitGroup + ch := make(chan upstreamResult, len(upstreams)) - sem := make(chan struct{}, maxConcurrentFetches) + sem := make(chan struct{}, maxConcurrentFetches) - client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} + client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} - for _, chn := range upstreams { - wg.Add(1) - go func(chItem dto.UpstreamDTO) { - defer wg.Done() + for _, chn := range upstreams { + wg.Add(1) + go func(chItem dto.UpstreamDTO) { + defer wg.Done() - sem <- struct{}{} - defer func() { <-sem }() + sem <- struct{}{} + defer func() { <-sem }() - endpoint := chItem.Endpoint - if endpoint == "" { - endpoint = defaultEndpoint - } else if !strings.HasPrefix(endpoint, "/") { - endpoint = "/" + endpoint - } - fullURL := chItem.BaseURL + endpoint + endpoint := chItem.Endpoint + if endpoint == "" { + endpoint = defaultEndpoint + } else if !strings.HasPrefix(endpoint, "/") { + endpoint = "/" + endpoint + } + fullURL := chItem.BaseURL + endpoint - uniqueName := chItem.Name - if chItem.ID != 0 { - uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) - } + uniqueName := chItem.Name + if chItem.ID != 0 { + uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) + } - ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) + defer cancel() - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) - if err != nil { - common.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } - resp, err := client.Do(httpReq) - if err != nil { - common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) - ch <- upstreamResult{Name: uniqueName, Err: resp.Status} - return - } - // 兼容两种上游接口格式: - // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price - // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 - var body struct { - Success bool `json:"success"` - Data json.RawMessage `json:"data"` - Message string `json:"message"` - } + resp, err := client.Do(httpReq) + if err != nil { + logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) + ch <- upstreamResult{Name: uniqueName, Err: resp.Status} + return + } + // 兼容两种上游接口格式: + // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price + // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 + var body struct { + Success bool `json:"success"` + Data json.RawMessage `json:"data"` + Message string `json:"message"` + } - if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { - common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } - if !body.Success { - ch <- upstreamResult{Name: uniqueName, Err: body.Message} - return - } + if !body.Success { + ch <- upstreamResult{Name: uniqueName, Err: body.Message} + return + } - // 尝试按 type1 解析 - var type1Data map[string]any - if err := json.Unmarshal(body.Data, &type1Data); err == nil { - // 如果包含至少一个 ratioTypes 字段,则认为是 type1 - isType1 := false - for _, rt := range ratioTypes { - if _, ok := type1Data[rt]; ok { - isType1 = true - break - } - } - if isType1 { - ch <- upstreamResult{Name: uniqueName, Data: type1Data} - return - } - } + // 尝试按 type1 解析 + var type1Data map[string]any + if err := json.Unmarshal(body.Data, &type1Data); err == nil { + // 如果包含至少一个 ratioTypes 字段,则认为是 type1 + isType1 := false + for _, rt := range ratioTypes { + if _, ok := type1Data[rt]; ok { + isType1 = true + break + } + } + if isType1 { + ch <- upstreamResult{Name: uniqueName, Data: type1Data} + return + } + } - // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析 - var pricingItems []struct { - ModelName string `json:"model_name"` - QuotaType int `json:"quota_type"` - ModelRatio float64 `json:"model_ratio"` - ModelPrice float64 `json:"model_price"` - CompletionRatio float64 `json:"completion_ratio"` - } - if err := json.Unmarshal(body.Data, &pricingItems); err != nil { - common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} - return - } + // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析 + var pricingItems []struct { + ModelName string `json:"model_name"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + CompletionRatio float64 `json:"completion_ratio"` + } + if err := json.Unmarshal(body.Data, &pricingItems); err != nil { + logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} + return + } - modelRatioMap := make(map[string]float64) - completionRatioMap := make(map[string]float64) - modelPriceMap := make(map[string]float64) + modelRatioMap := make(map[string]float64) + completionRatioMap := make(map[string]float64) + modelPriceMap := make(map[string]float64) - for _, item := range pricingItems { - if item.QuotaType == 1 { - modelPriceMap[item.ModelName] = item.ModelPrice - } else { - modelRatioMap[item.ModelName] = item.ModelRatio - // completionRatio 可能为 0,此时也直接赋值,保持与上游一致 - completionRatioMap[item.ModelName] = item.CompletionRatio - } - } + for _, item := range pricingItems { + if item.QuotaType == 1 { + modelPriceMap[item.ModelName] = item.ModelPrice + } else { + modelRatioMap[item.ModelName] = item.ModelRatio + // completionRatio 可能为 0,此时也直接赋值,保持与上游一致 + completionRatioMap[item.ModelName] = item.CompletionRatio + } + } - converted := make(map[string]any) + converted := make(map[string]any) - if len(modelRatioMap) > 0 { - ratioAny := make(map[string]any, len(modelRatioMap)) - for k, v := range modelRatioMap { - ratioAny[k] = v - } - converted["model_ratio"] = ratioAny - } + if len(modelRatioMap) > 0 { + ratioAny := make(map[string]any, len(modelRatioMap)) + for k, v := range modelRatioMap { + ratioAny[k] = v + } + converted["model_ratio"] = ratioAny + } - if len(completionRatioMap) > 0 { - compAny := make(map[string]any, len(completionRatioMap)) - for k, v := range completionRatioMap { - compAny[k] = v - } - converted["completion_ratio"] = compAny - } + if len(completionRatioMap) > 0 { + compAny := make(map[string]any, len(completionRatioMap)) + for k, v := range completionRatioMap { + compAny[k] = v + } + converted["completion_ratio"] = compAny + } - if len(modelPriceMap) > 0 { - priceAny := make(map[string]any, len(modelPriceMap)) - for k, v := range modelPriceMap { - priceAny[k] = v - } - converted["model_price"] = priceAny - } + if len(modelPriceMap) > 0 { + priceAny := make(map[string]any, len(modelPriceMap)) + for k, v := range modelPriceMap { + priceAny[k] = v + } + converted["model_price"] = priceAny + } - ch <- upstreamResult{Name: uniqueName, Data: converted} - }(chn) - } + ch <- upstreamResult{Name: uniqueName, Data: converted} + }(chn) + } - wg.Wait() - close(ch) + wg.Wait() + close(ch) - localData := ratio_setting.GetExposedData() + localData := ratio_setting.GetExposedData() - var testResults []dto.TestResult - var successfulChannels []struct { - name string - data map[string]any - } + var testResults []dto.TestResult + var successfulChannels []struct { + name string + data map[string]any + } - for r := range ch { - if r.Err != "" { - testResults = append(testResults, dto.TestResult{ - Name: r.Name, - Status: "error", - Error: r.Err, - }) - } else { - testResults = append(testResults, dto.TestResult{ - Name: r.Name, - Status: "success", - }) - successfulChannels = append(successfulChannels, struct { - name string - data map[string]any - }{name: r.Name, data: r.Data}) - } - } + for r := range ch { + if r.Err != "" { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "error", + Error: r.Err, + }) + } else { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "success", + }) + successfulChannels = append(successfulChannels, struct { + name string + data map[string]any + }{name: r.Name, data: r.Data}) + } + } - differences := buildDifferences(localData, successfulChannels) + differences := buildDifferences(localData, successfulChannels) - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": gin.H{ - "differences": differences, - "test_results": testResults, - }, - }) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "differences": differences, + "test_results": testResults, + }, + }) } func buildDifferences(localData map[string]any, successfulChannels []struct { - name string - data map[string]any + name string + data map[string]any }) map[string]map[string]dto.DifferenceItem { - differences := make(map[string]map[string]dto.DifferenceItem) + differences := make(map[string]map[string]dto.DifferenceItem) - allModels := make(map[string]struct{}) - - for _, ratioType := range ratioTypes { - if localRatioAny, ok := localData[ratioType]; ok { - if localRatio, ok := localRatioAny.(map[string]float64); ok { - for modelName := range localRatio { - allModels[modelName] = struct{}{} - } - } - } - } - - for _, channel := range successfulChannels { - for _, ratioType := range ratioTypes { - if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { - for modelName := range upstreamRatio { - allModels[modelName] = struct{}{} - } - } - } - } + allModels := make(map[string]struct{}) - confidenceMap := make(map[string]map[string]bool) - - // 预处理阶段:检查pricing接口的可信度 - for _, channel := range successfulChannels { - confidenceMap[channel.name] = make(map[string]bool) - - modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) - completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) - - if hasModelRatio && hasCompletionRatio { - // 遍历所有模型,检查是否满足不可信条件 - for modelName := range allModels { - // 默认为可信 - confidenceMap[channel.name][modelName] = true - - // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1 - if modelRatioVal, ok := modelRatios[modelName]; ok { - if completionRatioVal, ok := completionRatios[modelName]; ok { - // 转换为float64进行比较 - if modelRatioFloat, ok := modelRatioVal.(float64); ok { - if completionRatioFloat, ok := completionRatioVal.(float64); ok { - if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { - confidenceMap[channel.name][modelName] = false - } - } - } - } - } - } - } else { - // 如果不是从pricing接口获取的数据,则全部标记为可信 - for modelName := range allModels { - confidenceMap[channel.name][modelName] = true - } - } - } + for _, ratioType := range ratioTypes { + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + for modelName := range localRatio { + allModels[modelName] = struct{}{} + } + } + } + } - for modelName := range allModels { - for _, ratioType := range ratioTypes { - var localValue interface{} = nil - if localRatioAny, ok := localData[ratioType]; ok { - if localRatio, ok := localRatioAny.(map[string]float64); ok { - if val, exists := localRatio[modelName]; exists { - localValue = val - } - } - } + for _, channel := range successfulChannels { + for _, ratioType := range ratioTypes { + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + for modelName := range upstreamRatio { + allModels[modelName] = struct{}{} + } + } + } + } - upstreamValues := make(map[string]interface{}) - confidenceValues := make(map[string]bool) - hasUpstreamValue := false - hasDifference := false + confidenceMap := make(map[string]map[string]bool) - for _, channel := range successfulChannels { - var upstreamValue interface{} = nil - - if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { - if val, exists := upstreamRatio[modelName]; exists { - upstreamValue = val - hasUpstreamValue = true - - if localValue != nil && localValue != val { - hasDifference = true - } else if localValue == val { - upstreamValue = "same" - } - } - } - if upstreamValue == nil && localValue == nil { - upstreamValue = "same" - } - - if localValue == nil && upstreamValue != nil && upstreamValue != "same" { - hasDifference = true - } - - upstreamValues[channel.name] = upstreamValue - - confidenceValues[channel.name] = confidenceMap[channel.name][modelName] - } + // 预处理阶段:检查pricing接口的可信度 + for _, channel := range successfulChannels { + confidenceMap[channel.name] = make(map[string]bool) - shouldInclude := false - - if localValue != nil { - if hasDifference { - shouldInclude = true - } - } else { - if hasUpstreamValue { - shouldInclude = true - } - } + modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) + completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) - if shouldInclude { - if differences[modelName] == nil { - differences[modelName] = make(map[string]dto.DifferenceItem) - } - differences[modelName][ratioType] = dto.DifferenceItem{ - Current: localValue, - Upstreams: upstreamValues, - Confidence: confidenceValues, - } - } - } - } + if hasModelRatio && hasCompletionRatio { + // 遍历所有模型,检查是否满足不可信条件 + for modelName := range allModels { + // 默认为可信 + confidenceMap[channel.name][modelName] = true - channelHasDiff := make(map[string]bool) - for _, ratioMap := range differences { - for _, item := range ratioMap { - for chName, val := range item.Upstreams { - if val != nil && val != "same" { - channelHasDiff[chName] = true - } - } - } - } + // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1 + if modelRatioVal, ok := modelRatios[modelName]; ok { + if completionRatioVal, ok := completionRatios[modelName]; ok { + // 转换为float64进行比较 + if modelRatioFloat, ok := modelRatioVal.(float64); ok { + if completionRatioFloat, ok := completionRatioVal.(float64); ok { + if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { + confidenceMap[channel.name][modelName] = false + } + } + } + } + } + } + } else { + // 如果不是从pricing接口获取的数据,则全部标记为可信 + for modelName := range allModels { + confidenceMap[channel.name][modelName] = true + } + } + } - for modelName, ratioMap := range differences { - for ratioType, item := range ratioMap { - for chName := range item.Upstreams { - if !channelHasDiff[chName] { - delete(item.Upstreams, chName) - delete(item.Confidence, chName) - } - } + for modelName := range allModels { + for _, ratioType := range ratioTypes { + var localValue interface{} = nil + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + if val, exists := localRatio[modelName]; exists { + localValue = val + } + } + } - allSame := true - for _, v := range item.Upstreams { - if v != "same" { - allSame = false - break - } - } - if len(item.Upstreams) == 0 || allSame { - delete(ratioMap, ratioType) - } else { - differences[modelName][ratioType] = item - } - } + upstreamValues := make(map[string]interface{}) + confidenceValues := make(map[string]bool) + hasUpstreamValue := false + hasDifference := false - if len(ratioMap) == 0 { - delete(differences, modelName) - } - } + for _, channel := range successfulChannels { + var upstreamValue interface{} = nil - return differences + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + if val, exists := upstreamRatio[modelName]; exists { + upstreamValue = val + hasUpstreamValue = true + + if localValue != nil && localValue != val { + hasDifference = true + } else if localValue == val { + upstreamValue = "same" + } + } + } + if upstreamValue == nil && localValue == nil { + upstreamValue = "same" + } + + if localValue == nil && upstreamValue != nil && upstreamValue != "same" { + hasDifference = true + } + + upstreamValues[channel.name] = upstreamValue + + confidenceValues[channel.name] = confidenceMap[channel.name][modelName] + } + + shouldInclude := false + + if localValue != nil { + if hasDifference { + shouldInclude = true + } + } else { + if hasUpstreamValue { + shouldInclude = true + } + } + + if shouldInclude { + if differences[modelName] == nil { + differences[modelName] = make(map[string]dto.DifferenceItem) + } + differences[modelName][ratioType] = dto.DifferenceItem{ + Current: localValue, + Upstreams: upstreamValues, + Confidence: confidenceValues, + } + } + } + } + + channelHasDiff := make(map[string]bool) + for _, ratioMap := range differences { + for _, item := range ratioMap { + for chName, val := range item.Upstreams { + if val != nil && val != "same" { + channelHasDiff[chName] = true + } + } + } + } + + for modelName, ratioMap := range differences { + for ratioType, item := range ratioMap { + for chName := range item.Upstreams { + if !channelHasDiff[chName] { + delete(item.Upstreams, chName) + delete(item.Confidence, chName) + } + } + + allSame := true + for _, v := range item.Upstreams { + if v != "same" { + allSame = false + break + } + } + if len(item.Upstreams) == 0 || allSame { + delete(ratioMap, ratioType) + } else { + differences[modelName][ratioType] = item + } + } + + if len(ratioMap) == 0 { + delete(differences, modelName) + } + } + + return differences } func GetSyncableChannels(c *gin.Context) { - channels, err := model.GetAllChannels(0, 0, true, false) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } + channels, err := model.GetAllChannels(0, 0, true, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } - var syncableChannels []dto.SyncableChannel - for _, channel := range channels { - if channel.GetBaseURL() != "" { - syncableChannels = append(syncableChannels, dto.SyncableChannel{ - ID: channel.Id, - Name: channel.Name, - BaseURL: channel.GetBaseURL(), - Status: channel.Status, - }) - } - } + var syncableChannels []dto.SyncableChannel + for _, channel := range channels { + if channel.GetBaseURL() != "" { + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: channel.Id, + Name: channel.Name, + BaseURL: channel.GetBaseURL(), + Status: channel.Status, + }) + } + } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": syncableChannels, - }) -} \ No newline at end of file + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": syncableChannels, + }) +} diff --git a/controller/relay.go b/controller/relay.go index d235f550..b0c995fb 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,21 +2,22 @@ package controller import ( "bytes" - "errors" "fmt" "io" "log" "net/http" "one-api/common" "one-api/constant" - constant2 "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/relay" + relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/setting" "one-api/types" "strings" @@ -24,81 +25,177 @@ import ( "github.com/gorilla/websocket" ) -func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { +func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError { var err *types.NewAPIError - switch relayMode { + switch info.RelayMode { case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: - err = relay.ImageHelper(c) + err = relay.ImageHelper(c, info) case relayconstant.RelayModeAudioSpeech: fallthrough case relayconstant.RelayModeAudioTranslation: fallthrough case relayconstant.RelayModeAudioTranscription: - err = relay.AudioHelper(c) + err = relay.AudioHelper(c, info) case relayconstant.RelayModeRerank: - err = relay.RerankHelper(c, relayMode) + err = relay.RerankHelper(c, info) case relayconstant.RelayModeEmbeddings: - err = relay.EmbeddingHelper(c) + err = relay.EmbeddingHelper(c, info) case relayconstant.RelayModeResponses: - err = relay.ResponsesHelper(c) - case relayconstant.RelayModeGemini: - if strings.Contains(c.Request.URL.Path, "embed") { - err = relay.GeminiEmbeddingHandler(c) - } else { - err = relay.GeminiHelper(c) - } + err = relay.ResponsesHelper(c, info) default: - err = relay.TextHelper(c) + err = relay.TextHelper(c, info) } - - if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) { - // 保存错误日志到mysql中 - userId := c.GetInt("id") - tokenName := c.GetString("token_name") - modelName := c.GetString("original_model") - tokenId := c.GetInt("token_id") - userGroup := c.GetString("group") - channelId := c.GetInt("channel_id") - other := make(map[string]interface{}) - other["error_type"] = err.GetErrorType() - other["error_code"] = err.GetErrorCode() - other["status_code"] = err.StatusCode - other["channel_id"] = channelId - other["channel_name"] = c.GetString("channel_name") - other["channel_type"] = c.GetInt("channel_type") - adminInfo := make(map[string]interface{}) - adminInfo["use_channel"] = c.GetStringSlice("use_channel") - isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey) - if isMultiKey { - adminInfo["is_multi_key"] = true - adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex) - } - other["admin_info"] = adminInfo - model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other) - } - return err } -func Relay(c *gin.Context) { - relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) +func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError { + var err *types.NewAPIError + if strings.Contains(c.Request.URL.Path, "embed") { + err = relay.GeminiEmbeddingHandler(c, info) + } else { + err = relay.GeminiHelper(c, info) + } + return err +} + +func Relay(c *gin.Context, relayFormat types.RelayFormat) { + requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError + + var ( + newAPIError *types.NewAPIError + ws *websocket.Conn + ) + + if relayFormat == types.RelayFormatOpenAIRealtime { + var err error + ws, err = upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) + return + } + defer ws.Close() + } + + defer func() { + if newAPIError != nil { + newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) + switch relayFormat { + case types.RelayFormatOpenAIRealtime: + helper.WssError(c, ws, newAPIError.ToOpenAIError()) + case types.RelayFormatClaude: + c.JSON(newAPIError.StatusCode, gin.H{ + "type": "error", + "error": newAPIError.ToClaudeError(), + }) + default: + c.JSON(newAPIError.StatusCode, gin.H{ + "error": newAPIError.ToOpenAIError(), + }) + } + } + }() + + request, err := helper.GetAndValidateRequest(c, relayFormat) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) + return + } + + relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed) + return + } + + meta := request.GetTokenCountMeta() + + if setting.ShouldCheckPromptSensitive() { + contains, words := service.CheckSensitiveText(meta.CombineText) + if contains { + logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) + newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return + } + } + + tokens, err := service.CountRequestToken(c, meta, relayInfo) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed) + return + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeModelPriceError) + return + } + + preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return + } + + defer func() { + // Only return quota if downstream failed and quota was actually pre-consumed + if newAPIError != nil && preConsumedQuota != 0 { + service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) + } + }() for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) newAPIError = err break } - newAPIError = relayRequest(c, relayMode, channel) + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + switch relayFormat { + case types.RelayFormatOpenAIRealtime: + newAPIError = relay.WssHelper(c, relayInfo) + case types.RelayFormatClaude: + newAPIError = relay.ClaudeHelper(c, relayInfo) + case types.RelayFormatGemini: + newAPIError = geminiRelayHandler(c, relayInfo) + default: + newAPIError = relayHandler(c, relayInfo) + } if newAPIError == nil { - return // 成功处理请求,直接返回 + return + } else { + if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) { + // 保存错误日志到mysql中 + userId := c.GetInt("id") + tokenName := c.GetString("token_name") + modelName := c.GetString("original_model") + tokenId := c.GetInt("token_id") + userGroup := c.GetString("group") + channelId := c.GetInt("channel_id") + other := make(map[string]interface{}) + other["error_type"] = newAPIError.GetErrorType() + other["error_code"] = newAPIError.GetErrorCode() + other["status_code"] = newAPIError.StatusCode + other["channel_id"] = channelId + other["channel_name"] = c.GetString("channel_name") + other["channel_type"] = c.GetInt("channel_type") + adminInfo := make(map[string]interface{}) + adminInfo["use_channel"] = c.GetStringSlice("use_channel") + isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey) + if isMultiKey { + adminInfo["is_multi_key"] = true + adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex) + } + other["admin_info"] = adminInfo + model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other) + } } go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) @@ -107,21 +204,11 @@ func Relay(c *gin.Context) { break } } + useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - //if newAPIError.StatusCode == http.StatusTooManyRequests { - // common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error())) - // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - //} - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - c.JSON(newAPIError.StatusCode, gin.H{ - "error": newAPIError.ToOpenAIError(), - }) + logger.LogInfo(c, retryLogStr) } } @@ -132,122 +219,6 @@ var upgrader = websocket.Upgrader{ }, } -func WssRelay(c *gin.Context) { - // 将 HTTP 连接升级为 WebSocket 连接 - - ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) - defer ws.Close() - - if err != nil { - helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) - return - } - - relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) - requestId := c.GetString(common.RequestIdKey) - group := c.GetString("group") - //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 - originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError - - for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, err.Error()) - newAPIError = err - break - } - - newAPIError = wssRequest(c, ws, relayMode, channel) - - if newAPIError == nil { - return // 成功处理请求,直接返回 - } - - go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - - if !shouldRetry(c, newAPIError, common.RetryTimes-i) { - break - } - } - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - //if newAPIError.StatusCode == http.StatusTooManyRequests { - // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - //} - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - helper.WssError(c, ws, newAPIError.ToOpenAIError()) - } -} - -func RelayClaude(c *gin.Context) { - //relayMode := constant.Path2RelayMode(c.Request.URL.Path) - requestId := c.GetString(common.RequestIdKey) - group := c.GetString("group") - originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError - - for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, err.Error()) - newAPIError = err - break - } - - newAPIError = claudeRequest(c, channel) - - if newAPIError == nil { - return // 成功处理请求,直接返回 - } - - go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - - if !shouldRetry(c, newAPIError, common.RetryTimes-i) { - break - } - } - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - c.JSON(newAPIError.StatusCode, gin.H{ - "type": "error", - "error": newAPIError.ToClaudeError(), - }) - } -} - -func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relayHandler(c, relayMode) -} - -func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relay.WssHelper(c, ws) -} - -func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relay.ClaudeHelper(c) -} - func addUsedChannel(c *gin.Context, channelId int) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) @@ -270,10 +241,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m } channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { - return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) + return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } if channel == nil { - return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) + return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) if newAPIError != nil { @@ -327,42 +298,52 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { service.DisableChannel(channelError, err.Error()) } } func RelayMidjourney(c *gin.Context) { - relayMode := c.GetInt("relay_mode") - var err *dto.MidjourneyResponse - switch relayMode { + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil) + + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "description": fmt.Sprintf("failed to generate relay info: %s", err.Error()), + "type": "upstream_error", + "code": 4, + }) + return + } + + var mjErr *dto.MidjourneyResponse + switch relayInfo.RelayMode { case relayconstant.RelayModeMidjourneyNotify: - err = relay.RelayMidjourneyNotify(c) + mjErr = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: - err = relay.RelayMidjourneyTask(c, relayMode) + mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode) case relayconstant.RelayModeMidjourneyTaskImageSeed: - err = relay.RelayMidjourneyTaskImageSeed(c) + mjErr = relay.RelayMidjourneyTaskImageSeed(c) case relayconstant.RelayModeSwapFace: - err = relay.RelaySwapFace(c) + mjErr = relay.RelaySwapFace(c, relayInfo) default: - err = relay.RelayMidjourneySubmit(c, relayMode) + mjErr = relay.RelayMidjourneySubmit(c, relayInfo) } //err = relayMidjourneySubmit(c, relayMode) - log.Println(err) - if err != nil { + log.Println(mjErr) + if mjErr != nil { statusCode := http.StatusBadRequest - if err.Code == 30 { - err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + if mjErr.Code == 30 { + mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" statusCode = http.StatusTooManyRequests } c.JSON(statusCode, gin.H{ - "description": fmt.Sprintf("%s %s", err.Description, err.Result), + "description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result), "type": "upstream_error", - "code": err.Code, + "code": mjErr.Code, }) channelId := c.GetInt("channel_id") - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result))) } } @@ -404,7 +385,7 @@ func RelayTask(c *gin.Context) { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { channel, newAPIError := getChannel(c, group, originalModel, i) if newAPIError != nil { - common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) + logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) break } @@ -412,7 +393,7 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) - common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) //middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, _ := common.GetRequestBody(c) @@ -422,7 +403,7 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) + logger.LogInfo(c, retryLogStr) } if taskErr != nil { if taskErr.StatusCode == http.StatusTooManyRequests { diff --git a/controller/task.go b/controller/task.go index 5fbdb424..1082d7a1 100644 --- a/controller/task.go +++ b/controller/task.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/relay" "sort" @@ -54,9 +55,9 @@ func UpdateTaskBulk() { "progress": "100%", }) if err != nil { - common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) } else { - common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { @@ -86,14 +87,14 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM for channelId, taskIds := range taskChannelM { err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) if err != nil { - common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) } } return nil } func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { - common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { return nil } @@ -106,7 +107,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas "progress": "100%", }) if err != nil { - common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) + common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) } return err } @@ -118,23 +119,23 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas "ids": taskIds, }) if err != nil { - common.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) return err } if resp.StatusCode != http.StatusOK { - common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) } defer resp.Body.Close() responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) return err } var responseItems dto.TaskResponse[[]dto.SunoDataResponse] err = json.Unmarshal(responseBody, &responseItems) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) return err } if !responseItems.IsSuccess() { @@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { - common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) task.Progress = "100%" //err = model.CacheUpdateUserQuota(task.UserId) ? if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) + logger.LogError(ctx, "error update user quota cache: "+err.Error()) } else { quota := task.Quota if quota != 0 { err = model.IncreaseUserQuota(task.UserId, quota, false) if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota)) + logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } @@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas err = task.Update() if err != nil { - common.SysError("UpdateMidjourneyTask task error: " + err.Error()) + common.SysLog("UpdateMidjourneyTask task error: " + err.Error()) } } return nil diff --git a/controller/task_video.go b/controller/task_video.go index 914bf6e6..ffb6728b 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -8,6 +8,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/relay" "one-api/relay/channel" @@ -18,14 +19,14 @@ import ( func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { for channelId, taskIds := range taskChannelM { if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { - common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) } } return nil } func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { - common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) if len(taskIds) == 0 { return nil } @@ -37,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha "progress": "100%", }) if errUpdate != nil { - common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) } return fmt.Errorf("CacheGetChannel failed: %w", err) } @@ -47,7 +48,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha } for _, taskId := range taskIds { if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { - common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) } } return nil @@ -61,7 +62,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task := taskM[taskId] if task == nil { - common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) return fmt.Errorf("task %s not found", taskId) } resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ @@ -112,7 +113,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.StartTime = now } case model.TaskStatusSuccess: - task.Progress = "100%" + task.Progress = "100%" if task.FinishTime == 0 { task.FinishTime = now } @@ -124,13 +125,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.FinishTime = now } task.FailReason = taskResult.Reason - common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) quota := task.Quota if quota != 0 { if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { - common.LogError(ctx, "Failed to increase user quota: "+err.Error()) + logger.LogError(ctx, "Failed to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) + logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } default: @@ -140,7 +141,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.Progress = taskResult.Progress } if err := task.Update(); err != nil { - common.SysError("UpdateVideoTask task error: " + err.Error()) + common.SysLog("UpdateVideoTask task error: " + err.Error()) } return nil diff --git a/controller/token.go b/controller/token.go index 62eb5474..399ccb4f 100644 --- a/controller/token.go +++ b/controller/token.go @@ -102,7 +102,7 @@ func AddToken(c *gin.Context) { "success": false, "message": "生成令牌失败", }) - common.SysError("failed to generate token key: " + err.Error()) + common.SysLog("failed to generate token key: " + err.Error()) return } cleanToken := model.Token{ diff --git a/controller/topup.go b/controller/topup.go index 827dda39..3f3c8623 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -5,6 +5,7 @@ import ( "log" "net/url" "one-api/common" + "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -231,7 +232,7 @@ func EpayNotify(c *gin.Context) { return } log.Printf("易支付回调更新用户成功 %v", topUp) - model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money)) + model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money)) } } else { log.Printf("易支付异常回调: %v", verifyInfo) diff --git a/controller/twofa.go b/controller/twofa.go index 9f48eed8..1859a128 100644 --- a/controller/twofa.go +++ b/controller/twofa.go @@ -70,7 +70,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成2FA密钥失败", }) - common.SysError("生成TOTP密钥失败: " + err.Error()) + common.SysLog("生成TOTP密钥失败: " + err.Error()) return } @@ -81,7 +81,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - common.SysError("生成备用码失败: " + err.Error()) + common.SysLog("生成备用码失败: " + err.Error()) return } @@ -115,7 +115,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - common.SysError("保存备用码失败: " + err.Error()) + common.SysLog("保存备用码失败: " + err.Error()) return } @@ -294,7 +294,7 @@ func Get2FAStatus(c *gin.Context) { // 获取剩余备用码数量 backupCount, err := model.GetUnusedBackupCodeCount(userId) if err != nil { - common.SysError("获取备用码数量失败: " + err.Error()) + common.SysLog("获取备用码数量失败: " + err.Error()) } else { status["backup_codes_remaining"] = backupCount } @@ -368,7 +368,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - common.SysError("生成备用码失败: " + err.Error()) + common.SysLog("生成备用码失败: " + err.Error()) return } @@ -378,7 +378,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - common.SysError("保存备用码失败: " + err.Error()) + common.SysLog("保存备用码失败: " + err.Error()) return } diff --git a/controller/user.go b/controller/user.go index 29cf83e1..a7d59f17 100644 --- a/controller/user.go +++ b/controller/user.go @@ -7,6 +7,7 @@ import ( "net/url" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/setting" "strconv" @@ -192,7 +193,7 @@ func Register(c *gin.Context) { "success": false, "message": "数据库错误,请稍后重试", }) - common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) + common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) return } if exist { @@ -235,7 +236,7 @@ func Register(c *gin.Context) { "success": false, "message": "生成默认令牌失败", }) - common.SysError("failed to generate token key: " + err.Error()) + common.SysLog("failed to generate token key: " + err.Error()) return } // 生成默认令牌 @@ -342,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) { "success": false, "message": "生成失败", }) - common.SysError("failed to generate key: " + err.Error()) + common.SysLog("failed to generate key: " + err.Error()) return } user.SetAccessToken(key) @@ -517,7 +518,7 @@ func UpdateUser(c *gin.Context) { return } if originUser.Quota != updatedUser.Quota { - model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) + model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota))) } c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/dto/audio.go b/dto/audio.go index c36b3da5..81872c69 100644 --- a/dto/audio.go +++ b/dto/audio.go @@ -1,5 +1,11 @@ package dto +import ( + "one-api/types" + + "github.com/gin-gonic/gin" +) + type AudioRequest struct { Model string `json:"model"` Input string `json:"input"` @@ -8,6 +14,18 @@ type AudioRequest struct { ResponseFormat string `json:"response_format,omitempty"` } +func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta { + meta := &types.TokenCountMeta{ + CombineText: r.Input, + TokenType: types.TokenTypeTextNumber, + } + return meta +} + +func (r *AudioRequest) IsStream(c *gin.Context) bool { + return false +} + type AudioResponse struct { Text string `json:"text"` } diff --git a/dto/claude.go b/dto/claude.go index 58a09217..48bef659 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -5,6 +5,9 @@ import ( "fmt" "one-api/common" "one-api/types" + "strings" + + "github.com/gin-gonic/gin" ) type ClaudeMetadata struct { @@ -81,7 +84,7 @@ func (c *ClaudeMediaMessage) GetStringContent() string { } func (c *ClaudeMediaMessage) GetJsonRowString() string { - jsonContent, _ := json.Marshal(c) + jsonContent, _ := common.Marshal(c) return string(jsonContent) } @@ -199,6 +202,129 @@ type ClaudeRequest struct { Thinking *Thinking `json:"thinking,omitempty"` } +func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { + var tokenCountMeta = types.TokenCountMeta{ + TokenType: types.TokenTypeTokenizer, + MaxTokens: int(c.MaxTokens), + } + + var texts = make([]string, 0) + var fileMeta = make([]*types.FileMeta, 0) + + // system + if c.System != nil { + if c.IsStringSystem() { + sys := c.GetStringSystem() + if sys != "" { + texts = append(texts, sys) + } + } else { + systemMedia := c.ParseSystem() + for _, media := range systemMedia { + switch media.Type { + case "text": + texts = append(texts, media.GetText()) + case "image": + if media.Source != nil { + data := media.Source.Url + if data == "" { + data = common.Interface2String(media.Source.Data) + } + if data != "" { + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) + } + } + } + } + } + } + + // messages + for _, message := range c.Messages { + tokenCountMeta.MessagesCount++ + texts = append(texts, message.Role) + if message.IsStringContent() { + content := message.GetStringContent() + if content != "" { + texts = append(texts, content) + } + continue + } + + content, _ := message.ParseContent() + for _, media := range content { + switch media.Type { + case "text": + texts = append(texts, media.GetText()) + case "image": + if media.Source != nil { + data := media.Source.Url + if data == "" { + data = common.Interface2String(media.Source.Data) + } + if data != "" { + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) + } + } + case "tool_use": + if media.Name != "" { + texts = append(texts, media.Name) + } + if media.Input != nil { + b, _ := common.Marshal(media.Input) + texts = append(texts, string(b)) + } + case "tool_result": + if media.Content != nil { + b, _ := common.Marshal(media.Content) + texts = append(texts, string(b)) + } + } + } + } + + // tools + if c.Tools != nil { + tools := c.GetTools() + normalTools, webSearchTools := ProcessTools(tools) + if normalTools != nil { + for _, t := range normalTools { + tokenCountMeta.ToolsCount++ + if t.Name != "" { + texts = append(texts, t.Name) + } + if t.Description != "" { + texts = append(texts, t.Description) + } + if t.InputSchema != nil { + b, _ := common.Marshal(t.InputSchema) + texts = append(texts, string(b)) + } + } + } + if webSearchTools != nil { + for _, t := range webSearchTools { + tokenCountMeta.ToolsCount++ + if t.Name != "" { + texts = append(texts, t.Name) + } + if t.UserLocation != nil { + b, _ := common.Marshal(t.UserLocation) + texts = append(texts, string(b)) + } + } + } + } + + tokenCountMeta.CombineText = strings.Join(texts, "\n") + tokenCountMeta.Files = fileMeta + return &tokenCountMeta +} + +func (claudeRequest *ClaudeRequest) IsStream(c *gin.Context) bool { + return claudeRequest.Stream +} + func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string { for _, message := range c.Messages { content, _ := message.ParseContent() diff --git a/dto/dalle.go b/dto/dalle.go deleted file mode 100644 index d1e66de9..00000000 --- a/dto/dalle.go +++ /dev/null @@ -1,32 +0,0 @@ -package dto - -import "encoding/json" - -type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style json.RawMessage `json:"style,omitempty"` - User json.RawMessage `json:"user,omitempty"` - ExtraFields json.RawMessage `json:"extra_fields,omitempty"` - Background json.RawMessage `json:"background,omitempty"` - Moderation json.RawMessage `json:"moderation,omitempty"` - OutputFormat json.RawMessage `json:"output_format,omitempty"` - OutputCompression json.RawMessage `json:"output_compression,omitempty"` - PartialImages json.RawMessage `json:"partial_images,omitempty"` - // Stream bool `json:"stream,omitempty"` - Watermark *bool `json:"watermark,omitempty"` -} - -type ImageResponse struct { - Data []ImageData `json:"data"` - Created int64 `json:"created"` -} -type ImageData struct { - Url string `json:"url"` - B64Json string `json:"b64_json"` - RevisedPrompt string `json:"revised_prompt"` -} diff --git a/dto/embedding.go b/dto/embedding.go index 9d722292..fff37776 100644 --- a/dto/embedding.go +++ b/dto/embedding.go @@ -1,5 +1,12 @@ package dto +import ( + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + type EmbeddingOptions struct { Seed int `json:"seed,omitempty"` Temperature *float64 `json:"temperature,omitempty"` @@ -24,9 +31,26 @@ type EmbeddingRequest struct { PresencePenalty float64 `json:"presence_penalty,omitempty"` } -func (r EmbeddingRequest) ParseInput() []string { +func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta { + var texts = make([]string, 0) + + inputs := r.ParseInput() + for _, input := range inputs { + texts = append(texts, input) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + } +} + +func (r *EmbeddingRequest) IsStream(c *gin.Context) bool { + return false +} + +func (r *EmbeddingRequest) ParseInput() []string { if r.Input == nil { - return nil + return make([]string, 0) } var input []string switch r.Input.(type) { diff --git a/dto/gemini.go b/dto/gemini.go index 6cb3e17a..b327de62 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -2,7 +2,10 @@ package dto import ( "encoding/json" + "github.com/gin-gonic/gin" "one-api/common" + "one-api/logger" + "one-api/types" "strings" ) @@ -14,19 +17,75 @@ type GeminiChatRequest struct { SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` } +func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { + var files []*types.FileMeta = make([]*types.FileMeta, 0) + + var maxTokens int + + if r.GenerationConfig.MaxOutputTokens > 0 { + maxTokens = int(r.GenerationConfig.MaxOutputTokens) + } + + var inputTexts []string + for _, content := range r.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + if part.InlineData != nil && part.InlineData.Data != "" { + if strings.HasPrefix(part.InlineData.MimeType, "image/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeImage, + Data: part.InlineData.Data, + }) + } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeAudio, + Data: part.InlineData.Data, + }) + } else if strings.HasPrefix(part.InlineData.MimeType, "video/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeVideo, + Data: part.InlineData.Data, + }) + } else { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeFile, + Data: part.InlineData.Data, + }) + } + } + } + } + + inputText := strings.Join(inputTexts, "\n") + return &types.TokenCountMeta{ + CombineText: inputText, + Files: files, + MaxTokens: maxTokens, + } +} + +func (r *GeminiChatRequest) IsStream(c *gin.Context) bool { + if c.Query("alt") == "sse" { + return true + } + return false +} + func (r *GeminiChatRequest) GetTools() []GeminiChatTool { var tools []GeminiChatTool if strings.HasSuffix(string(r.Tools), "[") { // is array if err := common.Unmarshal(r.Tools, &tools); err != nil { - common.LogError(nil, "error_unmarshalling_tools: "+err.Error()) + logger.LogError(nil, "error_unmarshalling_tools: "+err.Error()) return nil } } else if strings.HasPrefix(string(r.Tools), "{") { // is object singleTool := GeminiChatTool{} if err := common.Unmarshal(r.Tools, &singleTool); err != nil { - common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error()) + logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error()) return nil } tools = []GeminiChatTool{singleTool} @@ -43,7 +102,7 @@ func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) { // Marshal the tools to JSON data, err := common.Marshal(tools) if err != nil { - common.LogError(nil, "error_marshalling_tools: "+err.Error()) + logger.LogError(nil, "error_marshalling_tools: "+err.Error()) return } r.Tools = data diff --git a/dto/openai_image.go b/dto/openai_image.go new file mode 100644 index 00000000..970f5bb4 --- /dev/null +++ b/dto/openai_image.go @@ -0,0 +1,74 @@ +package dto + +import ( + "encoding/json" + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N uint `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style json.RawMessage `json:"style,omitempty"` + User json.RawMessage `json:"user,omitempty"` + ExtraFields json.RawMessage `json:"extra_fields,omitempty"` + Background json.RawMessage `json:"background,omitempty"` + Moderation json.RawMessage `json:"moderation,omitempty"` + OutputFormat json.RawMessage `json:"output_format,omitempty"` + OutputCompression json.RawMessage `json:"output_compression,omitempty"` + PartialImages json.RawMessage `json:"partial_images,omitempty"` + // Stream bool `json:"stream,omitempty"` + Watermark *bool `json:"watermark,omitempty"` +} + +func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { + var sizeRatio = 1.0 + var qualityRatio = 1.0 + + if strings.HasPrefix(i.Model, "dall-e") { + // Size + if i.Size == "256x256" { + sizeRatio = 0.4 + } else if i.Size == "512x512" { + sizeRatio = 0.45 + } else if i.Size == "1024x1024" { + sizeRatio = 1 + } else if i.Size == "1024x1792" || i.Size == "1792x1024" { + sizeRatio = 2 + } + + if i.Model == "dall-e-3" && i.Quality == "hd" { + qualityRatio = 2.0 + if i.Size == "1024x1792" || i.Size == "1792x1024" { + qualityRatio = 1.5 + } + } + } + + // not support token count for dalle + return &types.TokenCountMeta{ + CombineText: i.Prompt, + MaxTokens: 1584, + ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N), + } +} + +func (i *ImageRequest) IsStream(c *gin.Context) bool { + return false +} + +type ImageResponse struct { + Data []ImageData `json:"data"` + Created int64 `json:"created"` +} +type ImageData struct { + Url string `json:"url"` + B64Json string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt"` +} diff --git a/dto/openai_request.go b/dto/openai_request.go index 7a23ca5c..12aa54f4 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -2,8 +2,12 @@ package dto import ( "encoding/json" + "fmt" "one-api/common" + "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) type ResponseFormat struct { @@ -67,6 +71,116 @@ type GeneralOpenAIRequest struct { Extra map[string]json.RawMessage `json:"-"` } +func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { + var tokenCountMeta types.TokenCountMeta + var texts = make([]string, 0) + var fileMeta = make([]*types.FileMeta, 0) + + if r.Prompt != nil { + switch v := r.Prompt.(type) { + case string: + texts = append(texts, v) + case []any: + for _, item := range v { + if str, ok := item.(string); ok { + texts = append(texts, str) + } + } + default: + texts = append(texts, fmt.Sprintf("%v", r.Prompt)) + } + } + + if r.Input != nil { + inputs := r.ParseInput() + texts = append(texts, inputs...) + } + + if r.MaxCompletionTokens > r.MaxTokens { + tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens) + } else { + tokenCountMeta.MaxTokens = int(r.MaxTokens) + } + + for _, message := range r.Messages { + tokenCountMeta.MessagesCount++ + texts = append(texts, message.Role) + if message.Content != nil { + if message.Name != nil { + tokenCountMeta.NameCount++ + texts = append(texts, *message.Name) + } + arrayContent := message.ParseContent() + for _, m := range arrayContent { + if m.Type == ContentTypeImageURL { + imageUrl := m.GetImageMedia() + if imageUrl != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeImage, + } + meta.Data = imageUrl.Url + meta.Detail = imageUrl.Detail + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeInputAudio { + inputAudio := m.GetInputAudio() + if inputAudio != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeAudio, + } + meta.Data = inputAudio.Data + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeFile { + file := m.GetFile() + if file != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeFile, + } + meta.Data = file.FileData + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeVideoUrl { + videoUrl := m.GetVideoUrl() + if videoUrl != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeVideo, + } + meta.Data = videoUrl.Url + fileMeta = append(fileMeta, meta) + } + } else { + texts = append(texts, m.Text) + } + } + } + } + + if r.Tools != nil { + openaiTools := r.Tools + for _, tool := range openaiTools { + tokenCountMeta.ToolsCount++ + texts = append(texts, tool.Function.Name) + if tool.Function.Description != "" { + texts = append(texts, tool.Function.Description) + } + if tool.Function.Parameters != nil { + texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters)) + } + } + //toolTokens := CountTokenInput(countStr, request.Model) + //tkm += 8 + //tkm += toolTokens + } + tokenCountMeta.CombineText = strings.Join(texts, "\n") + tokenCountMeta.Files = fileMeta + return &tokenCountMeta +} + +func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool { + return r.Stream +} + func (r *GeneralOpenAIRequest) ToMap() map[string]any { result := make(map[string]any) data, _ := common.Marshal(r) @@ -202,6 +316,21 @@ func (m *MediaContent) GetFile() *MessageFile { return nil } +func (m *MediaContent) GetVideoUrl() *MessageVideoUrl { + if m.VideoUrl != nil { + if _, ok := m.VideoUrl.(*MessageVideoUrl); ok { + return m.VideoUrl.(*MessageVideoUrl) + } + if itemMap, ok := m.VideoUrl.(map[string]any); ok { + out := &MessageVideoUrl{ + Url: common.Interface2String(itemMap["url"]), + } + return out + } + } + return nil +} + type MessageImageUrl struct { Url string `json:"url"` Detail string `json:"detail"` @@ -233,6 +362,7 @@ const ( ContentTypeInputAudio = "input_audio" ContentTypeFile = "file" ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别 + //ContentTypeAudioUrl = "audio_url" ) func (m *Message) GetPrefix() bool { @@ -623,7 +753,7 @@ type WebSearchOptions struct { // https://platform.openai.com/docs/api-reference/responses/create type OpenAIResponsesRequest struct { Model string `json:"model"` - Input json.RawMessage `json:"input,omitempty"` + Input any `json:"input,omitempty"` Include json.RawMessage `json:"include,omitempty"` Instructions json.RawMessage `json:"instructions,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"` @@ -645,28 +775,145 @@ type OpenAIResponsesRequest struct { Prompt json.RawMessage `json:"prompt,omitempty"` } +func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { + var fileMeta = make([]*types.FileMeta, 0) + var texts = make([]string, 0) + + if r.Input != nil { + inputs := r.ParseInput() + for _, input := range inputs { + if input.Type == "input_image" { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeImage, + Data: input.ImageUrl, + Detail: input.Detail, + }) + } else if input.Type == "input_file" { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeFile, + Data: input.FileUrl, + }) + } else { + texts = append(texts, input.Text) + } + } + } + + if len(r.Instructions) > 0 { + texts = append(texts, string(r.Instructions)) + } + + if len(r.Metadata) > 0 { + texts = append(texts, string(r.Metadata)) + } + + if len(r.Text) > 0 { + texts = append(texts, string(r.Text)) + } + + if len(r.ToolChoice) > 0 { + texts = append(texts, string(r.ToolChoice)) + } + + if len(r.Prompt) > 0 { + texts = append(texts, string(r.Prompt)) + } + + if len(r.Tools) > 0 { + toolStr, _ := common.Marshal(r.Tools) + texts = append(texts, string(toolStr)) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + Files: fileMeta, + MaxTokens: int(r.MaxOutputTokens), + } +} + +func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool { + return r.Stream +} + type Reasoning struct { Effort string `json:"effort,omitempty"` Summary string `json:"summary,omitempty"` } -//type ResponsesToolsCall struct { -// Type string `json:"type"` -// // Web Search -// UserLocation json.RawMessage `json:"user_location,omitempty"` -// SearchContextSize string `json:"search_context_size,omitempty"` -// // File Search -// VectorStoreIds []string `json:"vector_store_ids,omitempty"` -// MaxNumResults uint `json:"max_num_results,omitempty"` -// Filters json.RawMessage `json:"filters,omitempty"` -// // Computer Use -// DisplayWidth uint `json:"display_width,omitempty"` -// DisplayHeight uint `json:"display_height,omitempty"` -// Environment string `json:"environment,omitempty"` -// // Function -// Name string `json:"name,omitempty"` -// Description string `json:"description,omitempty"` -// Parameters json.RawMessage `json:"parameters,omitempty"` -// Function json.RawMessage `json:"function,omitempty"` -// Container json.RawMessage `json:"container,omitempty"` -//} +type MediaInput struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + FileUrl string `json:"file_url,omitempty"` + ImageUrl string `json:"image_url,omitempty"` + Detail string `json:"detail,omitempty"` // 仅 input_image 有效 +} + +// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput. +// Reference implementation mirrors Message.ParseContent: +// - input can be a string, treated as an input_text item +// - input can be an array of objects with a `type` field +// supported types: input_text, input_image, input_file +func (r *OpenAIResponsesRequest) ParseInput() []MediaInput { + if r.Input == nil { + return nil + } + + var inputs []MediaInput + + // Try string first + if str, ok := r.Input.(string); ok { + inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) + return inputs + } + + // Try array of parts + if array, ok := r.Input.([]any); ok { + for _, itemAny := range array { + // Already parsed MediaInput + if media, ok := itemAny.(MediaInput); ok { + inputs = append(inputs, media) + continue + } + // Generic map + item, ok := itemAny.(map[string]any) + if !ok { + continue + } + typeVal, ok := item["type"].(string) + if !ok { + continue + } + switch typeVal { + case "input_text": + text, _ := item["text"].(string) + inputs = append(inputs, MediaInput{Type: "input_text", Text: text}) + case "input_image": + // image_url may be string or object with url field + var imageUrl string + switch v := item["image_url"].(type) { + case string: + imageUrl = v + case map[string]any: + if url, ok := v["url"].(string); ok { + imageUrl = url + } + } + inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl}) + case "input_file": + // file_url may be string or object with url field + var fileUrl string + switch v := item["file_url"].(type) { + case string: + fileUrl = v + case map[string]any: + if url, ok := v["url"].(string); ok { + fileUrl = url + } + } + inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl}) + } + } + } + + return inputs +} diff --git a/dto/request_common.go b/dto/request_common.go new file mode 100644 index 00000000..8bd25785 --- /dev/null +++ b/dto/request_common.go @@ -0,0 +1,24 @@ +package dto + +import ( + "github.com/gin-gonic/gin" + "one-api/types" +) + +type Request interface { + GetTokenCountMeta() *types.TokenCountMeta + IsStream(c *gin.Context) bool +} + +type BaseRequest struct { +} + +func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta { + return &types.TokenCountMeta{ + TokenType: types.TokenTypeTokenizer, + } +} + +func (b *BaseRequest) IsStream(c *gin.Context) bool { + return false +} diff --git a/dto/rerank.go b/dto/rerank.go index 5ea68cba..ca4da9e1 100644 --- a/dto/rerank.go +++ b/dto/rerank.go @@ -1,5 +1,12 @@ package dto +import ( + "fmt" + "github.com/gin-gonic/gin" + "one-api/types" + "strings" +) + type RerankRequest struct { Documents []any `json:"documents"` Query string `json:"query"` @@ -10,6 +17,26 @@ type RerankRequest struct { OverLapTokens int `json:"overlap_tokens,omitempty"` } +func (r *RerankRequest) IsStream(c *gin.Context) bool { + return false +} + +func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta { + var texts = make([]string, 0) + + for _, document := range r.Documents { + texts = append(texts, fmt.Sprintf("%v", document)) + } + + if r.Query != "" { + texts = append(texts, r.Query) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + } +} + func (r *RerankRequest) GetReturnDocuments() bool { if r.ReturnDocuments == nil { return false diff --git a/common/logger.go b/logger/logger.go similarity index 70% rename from common/logger.go rename to logger/logger.go index 0f6dc3c3..d59e51cb 100644 --- a/common/logger.go +++ b/logger/logger.go @@ -1,23 +1,26 @@ -package common +package logger import ( "context" "encoding/json" "fmt" - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" "io" "log" + "one-api/common" "os" "path/filepath" "sync" "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) const ( loggerINFO = "INFO" loggerWarn = "WARN" loggerError = "ERR" + loggerDebug = "DEBUG" ) const maxLogCount = 1000000 @@ -27,7 +30,10 @@ var setupLogLock sync.Mutex var setupLogWorking bool func SetupLogger() { - if *LogDir != "" { + defer func() { + setupLogWorking = false + }() + if *common.LogDir != "" { ok := setupLogLock.TryLock() if !ok { log.Println("setup log is already working") @@ -35,9 +41,8 @@ func SetupLogger() { } defer func() { setupLogLock.Unlock() - setupLogWorking = false }() - logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) + logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") @@ -47,16 +52,6 @@ func SetupLogger() { } } -func SysLog(s string) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) -} - -func SysError(s string) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) -} - func LogInfo(ctx context.Context, msg string) { logHelper(ctx, loggerINFO, msg) } @@ -69,12 +64,18 @@ func LogError(ctx context.Context, msg string) { logHelper(ctx, loggerError, msg) } +func LogDebug(ctx context.Context, msg string) { + if common.DebugEnabled { + logHelper(ctx, loggerDebug, msg) + } +} + func logHelper(ctx context.Context, level string, msg string) { writer := gin.DefaultErrorWriter if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(RequestIdKey) + id := ctx.Value(common.RequestIdKey) if id == nil { id = "SYSTEM" } @@ -90,23 +91,17 @@ func logHelper(ctx context.Context, level string, msg string) { } } -func FatalLog(v ...any) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) - os.Exit(1) -} - func LogQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) + if common.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit) } else { return fmt.Sprintf("%d 点额度", quota) } } func FormatQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit) + if common.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit) } else { return fmt.Sprintf("%d", quota) } diff --git a/main.go b/main.go index ca3da601..2dfddacc 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "one-api/common" "one-api/constant" "one-api/controller" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/router" @@ -60,13 +61,13 @@ func main() { } if common.MemoryCacheEnabled { common.SysLog("memory cache enabled") - common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) // Add panic recovery and retry for InitChannelCache func() { defer func() { if r := recover(); r != nil { - common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) + common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) // Retry once _, _, fixErr := model.FixAbility() if fixErr != nil { @@ -125,7 +126,7 @@ func main() { // Initialize HTTP server server := gin.New() server.Use(gin.CustomRecovery(func(c *gin.Context, err any) { - common.SysError(fmt.Sprintf("panic detected: %v", err)) + common.SysLog(fmt.Sprintf("panic detected: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), @@ -171,7 +172,7 @@ func InitResources() error { // 加载环境变量 common.InitEnv() - common.SetupLogger() + logger.SetupLogger() // Initialize model settings ratio_setting.InitRatioSettings() diff --git a/middleware/distributor.go b/middleware/distributor.go index 286a4d1f..28b66a3a 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -107,11 +107,11 @@ func Distribute() func(c *gin.Context) { // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) // message = "数据库一致性已被破坏,请联系管理员" //} - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message) + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound)) return } if channel == nil { - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model)) + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound)) return } } diff --git a/middleware/recover.go b/middleware/recover.go index 51fc7190..d78c8137 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - common.SysError(fmt.Sprintf("panic detected: %v", err)) - common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + common.SysLog(fmt.Sprintf("panic detected: %v", err)) + common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 26688810..106a7278 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -37,7 +37,7 @@ func TurnstileCheck() gin.HandlerFunc { "remoteip": {c.ClientIP()}, }) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -49,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/middleware/utils.go b/middleware/utils.go index 082f5657..77d1eb80 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -4,18 +4,24 @@ import ( "fmt" "github.com/gin-gonic/gin" "one-api/common" + "one-api/logger" ) -func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { +func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) { + codeStr := "" + if len(code) > 0 { + codeStr = code[0] + } userId := c.GetInt("id") c.JSON(statusCode, gin.H{ "error": gin.H{ "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "type": "new_api_error", + "code": codeStr, }, }) c.Abort() - common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) + logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) } func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { @@ -25,5 +31,5 @@ func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, descri "code": code, }) c.Abort() - common.LogError(c.Request.Context(), description) + logger.LogError(c.Request.Context(), description) } diff --git a/model/ability.go b/model/ability.go index ce2f299c..123fc7be 100644 --- a/model/ability.go +++ b/model/ability.go @@ -294,13 +294,13 @@ func FixAbility() (int, int, error) { if common.UsingSQLite { err := DB.Exec("DELETE FROM abilities").Error if err != nil { - common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error())) return 0, 0, err } } else { err := DB.Exec("TRUNCATE TABLE abilities").Error if err != nil { - common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) return 0, 0, err } } @@ -320,7 +320,7 @@ func FixAbility() (int, int, error) { // Delete all abilities of this channel err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error if err != nil { - common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error())) failCount += len(chunk) continue } @@ -328,7 +328,7 @@ func FixAbility() (int, int, error) { for _, channel := range chunk { err = channel.AddAbilities(nil) if err != nil { - common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) + common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) failCount++ } else { successCount++ diff --git a/model/channel.go b/model/channel.go index 6239f05c..a9a23481 100644 --- a/model/channel.go +++ b/model/channel.go @@ -209,7 +209,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { if channel.OtherInfo != "" { err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { - common.SysError("failed to unmarshal other info: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err)) } } return otherInfo @@ -218,7 +218,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { otherInfoBytes, err := json.Marshal(otherInfo) if err != nil { - common.SysError("failed to marshal other info: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err)) return } channel.OtherInfo = string(otherInfoBytes) @@ -406,7 +406,11 @@ func (channel *Channel) GetBaseURL() string { if channel.BaseURL == nil { return "" } - return *channel.BaseURL + url := *channel.BaseURL + if url == "" { + url = constant.ChannelBaseURLs[channel.Type] + } + return url } func (channel *Channel) GetModelMapping() string { @@ -488,7 +492,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { ResponseTime: int(responseTime), }).Error if err != nil { - common.SysError("failed to update response time: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err)) } } @@ -498,7 +502,7 @@ func (channel *Channel) UpdateBalance(balance float64) { Balance: balance, }).Error if err != nil { - common.SysError("failed to update balance: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err)) } } @@ -614,7 +618,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if shouldUpdateAbilities { err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err)) } } }() @@ -642,7 +646,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri } err = channel.Save() if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err)) return false } } @@ -704,7 +708,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models * for _, channel := range channels { err = channel.UpdateAbilities(nil) if err != nil { - common.SysError("failed to update abilities: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err)) } } } @@ -728,7 +732,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - common.SysError("failed to update channel used quota: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err)) } } @@ -821,7 +825,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { if channel.Setting != nil && *channel.Setting != "" { err := common.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err)) channel.Setting = nil // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -832,7 +836,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { func (channel *Channel) SetSetting(setting dto.ChannelSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err)) return } channel.Setting = common.GetPointer[string](string(settingBytes)) @@ -843,7 +847,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { if channel.OtherSettings != "" { err := common.UnmarshalJsonStr(channel.OtherSettings, &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err)) channel.OtherSettings = "{}" // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -854,7 +858,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err)) return } channel.OtherSettings = string(settingBytes) @@ -865,7 +869,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} { if channel.ParamOverride != nil && *channel.ParamOverride != "" { err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) if err != nil { - common.SysError("failed to unmarshal param override: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err)) } } return paramOverride diff --git a/model/log.go b/model/log.go index 2070cd6f..e443516d 100644 --- a/model/log.go +++ b/model/log.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "one-api/common" + "one-api/logger" "os" "strings" "time" @@ -87,13 +88,13 @@ func RecordLog(userId int, logType int, content string) { } err := LOG_DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + common.SysLog("failed to record log: " + err.Error()) } } func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int, isStream bool, group string, other map[string]interface{}) { - common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) + logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) username := c.GetString("username") otherStr := common.MapToJsonStr(other) // 判断是否需要记录 IP @@ -129,7 +130,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, } err := LOG_DB.Create(log).Error if err != nil { - common.LogError(c, "failed to record log: "+err.Error()) + logger.LogError(c, "failed to record log: "+err.Error()) } } @@ -142,7 +143,6 @@ type RecordConsumeLogParams struct { Quota int `json:"quota"` Content string `json:"content"` TokenId int `json:"token_id"` - UserQuota int `json:"user_quota"` UseTimeSeconds int `json:"use_time_seconds"` IsStream bool `json:"is_stream"` Group string `json:"group"` @@ -150,7 +150,7 @@ type RecordConsumeLogParams struct { } func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) { - common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) + logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) if !common.LogConsumeEnabled { return } @@ -189,7 +189,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) } err := LOG_DB.Create(log).Error if err != nil { - common.LogError(c, "failed to record log: "+err.Error()) + logger.LogError(c, "failed to record log: "+err.Error()) } if common.DataExportEnabled { gopool.Go(func() { diff --git a/model/option.go b/model/option.go index 5c84d166..2121710c 100644 --- a/model/option.go +++ b/model/option.go @@ -150,7 +150,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - common.SysError("failed to update option map: " + err.Error()) + common.SysLog("failed to update option map: " + err.Error()) } } } diff --git a/model/pricing.go b/model/pricing.go index 0936d298..3c9349de 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -92,7 +92,7 @@ func updatePricing() { //modelRatios := common.GetModelRatios() enableAbilities, err := GetAllEnableAbilityWithChannels() if err != nil { - common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) + common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) return } // 预加载模型元数据与供应商一次,避免循环查询 diff --git a/model/redemption.go b/model/redemption.go index bf237668..1ab84f45 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "strconv" "gorm.io/gorm" @@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return 0, errors.New("兑换失败," + err.Error()) } - RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id)) + RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id)) return redemption.Quota, nil } diff --git a/model/token.go b/model/token.go index e85a445e..320b5cf0 100644 --- a/model/token.go +++ b/model/token.go @@ -91,7 +91,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExpired err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + common.SysLog("failed to update token status" + err.Error()) } } return token, errors.New("该令牌已过期") @@ -102,7 +102,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExhausted err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + common.SysLog("failed to update token status" + err.Error()) } } keyPrefix := key[:3] @@ -134,7 +134,7 @@ func GetTokenById(id int) (*Token, error) { if shouldUpdateRedis(true, err) { gopool.Go(func() { if err := cacheSetToken(token); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } @@ -147,7 +147,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) { if shouldUpdateRedis(fromDB, err) && token != nil { gopool.Go(func() { if err := cacheSetToken(*token); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } @@ -178,7 +178,7 @@ func (token *Token) Update() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - common.SysError("failed to update token cache: " + err.Error()) + common.SysLog("failed to update token cache: " + err.Error()) } }) } @@ -194,7 +194,7 @@ func (token *Token) SelectUpdate() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - common.SysError("failed to update token cache: " + err.Error()) + common.SysLog("failed to update token cache: " + err.Error()) } }) } @@ -209,7 +209,7 @@ func (token *Token) Delete() (err error) { gopool.Go(func() { err := cacheDeleteToken(token.Key) if err != nil { - common.SysError("failed to delete token cache: " + err.Error()) + common.SysLog("failed to delete token cache: " + err.Error()) } }) } @@ -269,7 +269,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheIncrTokenQuota(key, int64(quota)) if err != nil { - common.SysError("failed to increase token quota: " + err.Error()) + common.SysLog("failed to increase token quota: " + err.Error()) } }) } @@ -299,7 +299,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheDecrTokenQuota(key, int64(quota)) if err != nil { - common.SysError("failed to decrease token quota: " + err.Error()) + common.SysLog("failed to decrease token quota: " + err.Error()) } }) } diff --git a/model/topup.go b/model/topup.go index c34c0ce6..802c866f 100644 --- a/model/topup.go +++ b/model/topup.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "gorm.io/gorm" ) @@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) { return errors.New("充值失败," + err.Error()) } - RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount)) + RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount)) return nil } diff --git a/model/twofa.go b/model/twofa.go index d09ff9fe..8e97289f 100644 --- a/model/twofa.go +++ b/model/twofa.go @@ -243,7 +243,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { if !common.ValidateTOTPCode(t.Secret, code) { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - common.SysError("更新2FA失败次数失败: " + err.Error()) + common.SysLog("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -255,7 +255,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - common.SysError("更新2FA使用记录失败: " + err.Error()) + common.SysLog("更新2FA使用记录失败: " + err.Error()) } return true, nil @@ -277,7 +277,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { if !valid { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - common.SysError("更新2FA失败次数失败: " + err.Error()) + common.SysLog("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -289,7 +289,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - common.SysError("更新2FA使用记录失败: " + err.Error()) + common.SysLog("更新2FA使用记录失败: " + err.Error()) } return true, nil diff --git a/model/user.go b/model/user.go index 6021f495..29d7a446 100644 --- a/model/user.go +++ b/model/user.go @@ -6,6 +6,7 @@ import ( "fmt" "one-api/common" "one-api/dto" + "one-api/logger" "strconv" "strings" @@ -75,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting { if user.Setting != "" { err := json.Unmarshal([]byte(user.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) } } return setting @@ -84,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting { func (user *User) SetSetting(setting dto.UserSetting) { settingBytes, err := json.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + common.SysLog("failed to marshal setting: " + err.Error()) return } user.Setting = string(settingBytes) @@ -274,7 +275,7 @@ func inviteUser(inviterId int) (err error) { func (user *User) TransferAffQuotaToQuota(quota int) error { // 检查quota是否小于最小额度 if float64(quota) < common.QuotaPerUnit { - return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit))) + return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit))) } // 开始数据库事务 @@ -324,16 +325,16 @@ func (user *User) Insert(inviterId int) error { return result.Error } if common.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) } if inviterId != 0 { if common.QuotaForInvitee > 0 { _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee))) } if common.QuotaForInviter > 0 { //_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter))) _ = inviteUser(inviterId) } } @@ -517,7 +518,7 @@ func IsAdmin(userId int) bool { var user User err := DB.Where("id = ?", userId).Select("role").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) + common.SysLog("no such user " + err.Error()) return false } return user.Role >= common.RoleAdminUser @@ -572,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserQuotaCache(id, quota); err != nil { - common.SysError("failed to update user quota cache: " + err.Error()) + common.SysLog("failed to update user quota cache: " + err.Error()) } }) } @@ -610,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserGroupCache(id, group); err != nil { - common.SysError("failed to update user group cache: " + err.Error()) + common.SysLog("failed to update user group cache: " + err.Error()) } }) } @@ -639,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserSettingCache(id, setting); err != nil { - common.SysError("failed to update user setting cache: " + err.Error()) + common.SysLog("failed to update user setting cache: " + err.Error()) } }) } @@ -669,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) { gopool.Go(func() { err := cacheIncrUserQuota(id, int64(quota)) if err != nil { - common.SysError("failed to increase user quota: " + err.Error()) + common.SysLog("failed to increase user quota: " + err.Error()) } }) if !db && common.BatchUpdateEnabled { @@ -694,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { gopool.Go(func() { err := cacheDecrUserQuota(id, int64(quota)) if err != nil { - common.SysError("failed to decrease user quota: " + err.Error()) + common.SysLog("failed to decrease user quota: " + err.Error()) } }) if common.BatchUpdateEnabled { @@ -750,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota and request count: " + err.Error()) + common.SysLog("failed to update user used quota and request count: " + err.Error()) return } @@ -767,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota: " + err.Error()) + common.SysLog("failed to update user used quota: " + err.Error()) } } func updateUserRequestCount(id int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error if err != nil { - common.SysError("failed to update user request count: " + err.Error()) + common.SysLog("failed to update user request count: " + err.Error()) } } @@ -785,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserNameCache(id, username); err != nil { - common.SysError("failed to update user name cache: " + err.Error()) + common.SysLog("failed to update user name cache: " + err.Error()) } }) } diff --git a/model/user_cache.go b/model/user_cache.go index a631457c..936e1a43 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -37,7 +37,7 @@ func (user *UserBase) GetSetting() dto.UserSetting { if user.Setting != "" { err := common.Unmarshal([]byte(user.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) } } return setting @@ -78,7 +78,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) { if shouldUpdateRedis(fromDB, err) && user != nil { gopool.Go(func() { if err := updateUserCache(*user); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } diff --git a/model/utils.go b/model/utils.go index 1f8a0963..dced2bc6 100644 --- a/model/utils.go +++ b/model/utils.go @@ -77,12 +77,12 @@ func batchUpdate() { case BatchUpdateTypeUserQuota: err := increaseUserQuota(key, value) if err != nil { - common.SysError("failed to batch update user quota: " + err.Error()) + common.SysLog("failed to batch update user quota: " + err.Error()) } case BatchUpdateTypeTokenQuota: err := increaseTokenQuota(key, value) if err != nil { - common.SysError("failed to batch update token quota: " + err.Error()) + common.SysLog("failed to batch update token quota: " + err.Error()) } case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 88777838..1bc2d90b 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -4,107 +4,40 @@ import ( "errors" "fmt" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/types" - "strings" "github.com/gin-gonic/gin" ) -func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { - audioRequest := &dto.AudioRequest{} - err := common.UnmarshalBodyReusable(c, audioRequest) - if err != nil { - return nil, err - } - switch info.RelayMode { - case relayconstant.RelayModeAudioSpeech: - if audioRequest.Model == "" { - return nil, errors.New("model is required") - } - if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveInput(audioRequest.Input) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) - return nil, err - } - } - default: - err = c.Request.ParseForm() - if err != nil { - return nil, err - } - formData := c.Request.PostForm - if audioRequest.Model == "" { - audioRequest.Model = formData.Get("model") - } +func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - if audioRequest.Model == "" { - return nil, errors.New("model is required") - } - audioRequest.ResponseFormat = formData.Get("response_format") - if audioRequest.ResponseFormat == "" { - audioRequest.ResponseFormat = "json" - } - } - return audioRequest, nil -} - -func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c) - audioRequest, err := getAndValidAudioRequest(c, relayInfo) - - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + audioRequest, ok := info.Request.(*dto.AudioRequest) + if !ok { + return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - promptTokens := 0 - preConsumedTokens := common.PreConsumedQuota - if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { - promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model) - preConsumedTokens = promptTokens - relayInfo.PromptTokens = promptTokens - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr - } - defer func() { - if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - err = helper.ModelMappedHelper(c, relayInfo, audioRequest) + err := helper.ModelMappedHelper(c, info, audioRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) - ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) + ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - resp, err := adaptor.DoRequest(c, relayInfo, ioReader) + resp, err := adaptor.DoRequest(c, info, ioReader) if err != nil { return types.NewError(err, types.ErrorCodeDoRequestFailed) } @@ -121,14 +54,14 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index bfb94008..0ae8a8d1 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -34,20 +34,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { var fullRequestURL string switch info.RelayFormat { - case relaycommon.RelayFormatClaude: - fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl) + case types.RelayFormatClaude: + fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl) default: switch info.RelayMode { case constant.RelayModeEmbeddings: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl) case constant.RelayModeRerank: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl) case constant.RelayModeImagesGenerations: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl) case constant.RelayModeCompletions: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl) default: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl) } } @@ -118,7 +118,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayFormat { - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: if info.IsStream { err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 754f29c8..645882bc 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/service" "one-api/types" @@ -22,14 +23,14 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { imageRequest.Input.Prompt = request.Prompt imageRequest.Model = request.Model imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) - imageRequest.Parameters.N = request.N + imageRequest.Parameters.N = int(request.N) imageRequest.ResponseFormat = request.ResponseFormat return &imageRequest } func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) { - url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID) + url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID) var aliResponse AliResponse @@ -43,7 +44,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error client := &http.Client{} resp, err := client.Do(req) if err != nil { - common.SysError("updateTask client.Do err: " + err.Error()) + common.SysLog("updateTask client.Do err: " + err.Error()) return &aliResponse, err, nil } defer resp.Body.Close() @@ -53,7 +54,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error var response AliResponse err = json.Unmarshal(responseBody, &response) if err != nil { - common.SysError("updateTask NewDecoder err: " + err.Error()) + common.SysLog("updateTask NewDecoder err: " + err.Error()) return &aliResponse, err, nil } @@ -109,7 +110,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc if responseFormat == "b64_json" { _, b64, err := service.GetImageFromUrl(data.Url) if err != nil { - common.LogError(c, "get_image_data_failed: "+err.Error()) + logger.LogError(c, "get_image_data_failed: "+err.Error()) continue } b64Json = b64 @@ -134,14 +135,14 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliTaskResponse) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliTaskResponse.Message != "" { - common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) + logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil } diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go index 4f448e01..e7d6b514 100644 --- a/relay/channel/ali/rerank.go +++ b/relay/channel/ali/rerank.go @@ -4,9 +4,9 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -36,7 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var aliResponse AliRerankResponse err = json.Unmarshal(responseBody, &aliResponse) diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index fcf63854..67b63286 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -8,6 +8,7 @@ import ( "one-api/common" "one-api/dto" "one-api/relay/helper" + "one-api/service" "strings" "one-api/types" @@ -46,7 +47,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) model := c.GetString("model") if model == "" { @@ -148,7 +149,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, var aliResponse AliResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } if aliResponse.Usage.OutputTokens != 0 { @@ -161,7 +162,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -171,7 +172,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, &usage } @@ -181,7 +182,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 3ccd2d78..fd745cf7 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -7,6 +7,7 @@ import ( "io" "net/http" common2 "one-api/common" + "one-api/logger" "one-api/relay/common" "one-api/relay/constant" "one-api/relay/helper" @@ -181,7 +182,7 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error { err := helper.PingData(c) if err != nil { - common2.LogError(c, "SSE ping error: "+err.Error()) + logger.LogError(c, "SSE ping error: "+err.Error()) done <- err return } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 8396a844..32e301ee 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -101,7 +101,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { default: suffix += strings.ToLower(info.UpstreamModelName) } - fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix) + fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix) var accessToken string var err error if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index a7cd5996..31e8319e 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -118,7 +118,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var baiduResponse BaiduChatStreamResponse err := common.Unmarshal([]byte(data), &baiduResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } if baiduResponse.Usage.TotalTokens != 0 { @@ -129,11 +129,11 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseBaidu2OpenAI(&baiduResponse) err = helper.ObjectData(c, response) if err != nil { - common.SysError("error sending stream response: " + err.Error()) + common.SysLog("error sending stream response: " + err.Error()) } return true }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, usage } @@ -143,7 +143,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil @@ -168,7 +168,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index ba59e307..6744f8ba 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -45,15 +45,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil case constant.RelayModeRerank: - return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 39b8ce2f..41583d30 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -53,9 +53,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if a.RequestMode == RequestModeMessage { - return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil } else { - return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index e4d3975e..57670bcf 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -375,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla for _, toolCall := range message.ParseToolCalls() { inputObj := make(map[string]any) if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { - common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) + common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) continue } claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ @@ -609,13 +610,13 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud var claudeResponse dto.ClaudeResponse err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return types.NewError(err, types.ErrorCodeBadResponseBody) } if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" { return types.WithClaudeError(*claudeError, http.StatusInternalServerError) } - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo) if requestMode == RequestModeCompletion { @@ -628,7 +629,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } } helper.ClaudeChunkData(c, claudeResponse, data) - } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { + } else if info.RelayFormat == types.RelayFormatOpenAI { response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) { @@ -637,7 +638,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud err = helper.ObjectData(c, response) if err != nil { - common.LogError(c, "send_stream_response_failed: "+err.Error()) + logger.LogError(c, "send_stream_response_failed: "+err.Error()) } } return nil @@ -653,21 +654,20 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau } if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done { if common.DebugEnabled { - common.SysError("claude response usage is not complete, maybe upstream error") + common.SysLog("claude response usage is not complete, maybe upstream error") } claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } } - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { // - } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { - + } else if info.RelayFormat == types.RelayFormatOpenAI { if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { - common.SysError("send final response failed: " + err.Error()) + common.SysLog("send final response failed: " + err.Error()) } } helper.Done(c) @@ -721,14 +721,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } var responseData []byte switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) openaiResponse.Usage = *claudeInfo.Usage responseData, err = json.Marshal(openaiResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody) } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: responseData = data } @@ -736,12 +736,12 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests) } - common.IOCopyBytesGracefully(c, nil, responseData) + service.IOCopyBytesGracefully(c, nil, responseData) return nil } func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) claudeInfo := &ClaudeResponseInfo{ ResponseId: helper.GetResponseID(c), diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 4b9f5028..bdea72f0 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -36,13 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil case constant.RelayModeResponses: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil default: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil } } diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index 5e8fe7f9..00f6b6c5 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -5,8 +5,8 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -51,7 +51,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res var response dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &response) if err != nil { - common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) + logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) continue } for _, choice := range response.Choices { @@ -66,24 +66,24 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res info.FirstResponseTime = time.Now() } if err != nil { - common.LogError(c, "error_rendering_stream_response: "+err.Error()) + logger.LogError(c, "error_rendering_stream_response: "+err.Error()) } } if err := scanner.Err(); err != nil { - common.LogError(c, "error_scanning_stream_response: "+err.Error()) + logger.LogError(c, "error_scanning_stream_response: "+err.Error()) } usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) err := helper.ObjectData(c, response) if err != nil { - common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) + logger.LogError(c, "error_rendering_final_usage_response: "+err.Error()) } } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, usage } @@ -93,7 +93,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var response dto.TextResponse err = json.Unmarshal(responseBody, &response) if err != nil { @@ -123,7 +123,7 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &cfResp) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 887f9efd..c8a38d46 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -43,9 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else { - return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index fcfb12b7..af357348 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -118,7 +118,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var cohereResp CohereResponse err := json.Unmarshal([]byte(data), &cohereResp) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } var openaiResp dto.ChatCompletionsStreamResponse @@ -153,7 +153,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } jsonStr, err := json.Marshal(openaiResp) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) @@ -175,7 +175,7 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var cohereResp CohereResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { @@ -216,7 +216,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon. if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var cohereResp CohereRerankResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index 658c6193..0f2a6fd3 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -122,7 +122,7 @@ func (a *Adaptor) GetModelList() []string { // GetRequestURL implements channel.Adaptor. func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil + return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil } // Init implements channel.Adaptor. diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 32cc6937..c480045f 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -49,7 +49,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) // convert coze response to openai response var response dto.TextResponse var cozeResponse CozeChatDetailResponse @@ -154,7 +154,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var chatData CozeChatResponseData err := json.Unmarshal([]byte(data), &chatData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } @@ -171,14 +171,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var messageData CozeChatV3MessageDetail err := json.Unmarshal([]byte(data), &messageData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } var content string err = json.Unmarshal(messageData.Content, &content) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } @@ -203,16 +203,16 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var errorData CozeError err := json.Unmarshal([]byte(data), &errorData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } - common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) } } func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { - requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) + requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl) requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") // 将 conversationId和chatId作为参数发送get请求 @@ -258,7 +258,7 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo } func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) { - requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl) + requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl) requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") req, err := http.NewRequest("GET", requestURL, nil) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index be8de0c8..17d732ab 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - fimBaseUrl := info.BaseUrl - if !strings.HasSuffix(info.BaseUrl, "/beta") { + fimBaseUrl := info.ChannelBaseUrl + if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { fimBaseUrl += "/beta" } switch info.RelayMode { case constant.RelayModeCompletions: return fmt.Sprintf("%s/completions", fimBaseUrl), nil default: - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 8c7898c9..0a08d035 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -61,13 +61,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch a.BotType { case BotTypeWorkFlow: - return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil case BotTypeCompletion: - return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil case BotTypeAgent: fallthrough default: - return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 47337127..2336fd4c 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -22,7 +22,7 @@ import ( ) func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile { - uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl) + uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl) switch media.Type { case dto.ContentTypeImageURL: // Decode base64 data @@ -36,14 +36,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Decode base64 string decodedData, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { - common.SysError("failed to decode base64: " + err.Error()) + common.SysLog("failed to decode base64: " + err.Error()) return nil } // Create temporary file tempFile, err := os.CreateTemp("", "dify-upload-*") if err != nil { - common.SysError("failed to create temp file: " + err.Error()) + common.SysLog("failed to create temp file: " + err.Error()) return nil } defer tempFile.Close() @@ -51,7 +51,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Write decoded data to temp file if _, err := tempFile.Write(decodedData); err != nil { - common.SysError("failed to write to temp file: " + err.Error()) + common.SysLog("failed to write to temp file: " + err.Error()) return nil } @@ -61,7 +61,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Add user field if err := writer.WriteField("user", user); err != nil { - common.SysError("failed to add user field: " + err.Error()) + common.SysLog("failed to add user field: " + err.Error()) return nil } @@ -74,13 +74,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create form file part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/"))) if err != nil { - common.SysError("failed to create form file: " + err.Error()) + common.SysLog("failed to create form file: " + err.Error()) return nil } // Copy file content to form if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil { - common.SysError("failed to copy file content: " + err.Error()) + common.SysLog("failed to copy file content: " + err.Error()) return nil } writer.Close() @@ -88,7 +88,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create HTTP request req, err := http.NewRequest("POST", uploadUrl, body) if err != nil { - common.SysError("failed to create request: " + err.Error()) + common.SysLog("failed to create request: " + err.Error()) return nil } @@ -99,7 +99,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me client := service.GetHttpClient() resp, err := client.Do(req) if err != nil { - common.SysError("failed to send request: " + err.Error()) + common.SysLog("failed to send request: " + err.Error()) return nil } defer resp.Body.Close() @@ -109,7 +109,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me Id string `json:"id"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - common.SysError("failed to decode response: " + err.Error()) + common.SysLog("failed to decode response: " + err.Error()) return nil } @@ -219,7 +219,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R var difyResponse DifyChunkChatCompletionResponse err := json.Unmarshal([]byte(data), &difyResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } var openaiResponse dto.ChatCompletionsStreamResponse @@ -239,7 +239,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } err = helper.ObjectData(c, openaiResponse) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) } return true }) @@ -258,7 +258,7 @@ func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &difyResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 4141caf7..99b6645e 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf }, }, Parameters: dto.GeminiImageParameters{ - SampleCount: request.N, + SampleCount: int(request.N), AspectRatio: aspectRatio, PersonGeneration: "allow_adult", // default allow adult }, @@ -108,7 +108,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName) if strings.HasPrefix(info.UpstreamModelName, "imagen") { - return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil + return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil } if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || @@ -118,7 +118,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.IsGeminiBatchEmbedding { action = "batchEmbedContents" } - return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil + return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil } action := "generateContent" @@ -128,7 +128,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { info.DisablePing = true } } - return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil + return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 7f2f51fb..974a22f5 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -5,6 +5,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -17,7 +18,7 @@ import ( ) func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // 读取响应体 responseBody, err := io.ReadAll(resp.Body) @@ -53,13 +54,13 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re } } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &usage, nil } func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -89,7 +90,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel } } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return usage, nil } @@ -106,7 +107,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn var geminiResponse dto.GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { - common.LogError(c, "error unmarshalling stream response: "+err.Error()) + logger.LogError(c, "error unmarshalling stream response: "+err.Error()) return false } @@ -140,7 +141,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn // 直接发送 GeminiChatResponse 响应 err = helper.StringData(c, data) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } info.SendResponseCount++ return true diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 58efa1a5..af5e8233 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -901,7 +902,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * var geminiResponse dto.GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { - common.LogError(c, "error unmarshalling stream response: "+err.Error()) + logger.LogError(c, "error unmarshalling stream response: "+err.Error()) return false } @@ -945,7 +946,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * finishReason = constant.FinishReasonToolCalls err = handleStream(c, info, emptyResponse) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } response.ClearToolCalls() @@ -957,7 +958,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * err = handleStream(c, info, response) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } if isStop { _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)) @@ -993,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) err := handleFinalStream(c, info, response) if err != nil { - common.SysError("send final response failed: " + err.Error()) + common.SysLog("send final response failed: " + err.Error()) } //if info.RelayFormat == relaycommon.RelayFormatOpenAI { // helper.Done(c) @@ -1007,7 +1008,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println(string(responseBody)) } @@ -1041,29 +1042,29 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R fullTextResponse.Usage = usage switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: responseBody, err = common.Marshal(fullTextResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info) claudeRespStr, err := common.Marshal(claudeResp) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } responseBody = claudeRespStr - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: break } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &usage, nil } func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { @@ -1107,7 +1108,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } diff --git a/relay/channel/jimeng/adaptor.go b/relay/channel/jimeng/adaptor.go index ff9ac678..885a1427 100644 --- a/relay/channel/jimeng/adaptor.go +++ b/relay/channel/jimeng/adaptor.go @@ -32,7 +32,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil + return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go index 28af1866..11a0117b 100644 --- a/relay/channel/jimeng/image.go +++ b/relay/channel/jimeng/image.go @@ -5,9 +5,9 @@ import ( "fmt" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -54,7 +54,7 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &jimengResponse) if err != nil { diff --git a/relay/channel/jimeng/sign.go b/relay/channel/jimeng/sign.go index c9db6630..d8b598dc 100644 --- a/relay/channel/jimeng/sign.go +++ b/relay/channel/jimeng/sign.go @@ -12,7 +12,7 @@ import ( "io" "net/http" "net/url" - "one-api/common" + "one-api/logger" "sort" "strings" "time" @@ -44,7 +44,7 @@ func SetPayloadHash(c *gin.Context, req any) error { if err != nil { return err } - common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) + logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) payloadHash := sha256.Sum256(body) hexPayloadHash := hex.EncodeToString(payloadHash[:]) c.Set(HexPayloadHashKey, hexPayloadHash) diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index bf318aa7..a383728f 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -45,9 +45,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } return "", errors.New("invalid relay mode") } diff --git a/relay/channel/minimax/relay-minimax.go b/relay/channel/minimax/relay-minimax.go index d0a15b0d..ff9b72ea 100644 --- a/relay/channel/minimax/relay-minimax.go +++ b/relay/channel/minimax/relay-minimax.go @@ -6,5 +6,5 @@ import ( ) func GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 45cb3290..f98ff869 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -41,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index 37db2aec..f9da685f 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -54,7 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if strings.HasPrefix(info.UpstreamModelName, "m3e") { suffix = "embeddings" } - fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix) + fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix) return fullRequestURL, nil } diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go index 78f96d6d..d91aceb3 100644 --- a/relay/channel/mokaai/relay-mokaai.go +++ b/relay/channel/mokaai/relay-mokaai.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -56,7 +57,7 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) @@ -77,6 +78,6 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &fullTextResponse.Usage, nil } diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go index d540388d..29004d0c 100644 --- a/relay/channel/moonshot/adaptor.go +++ b/relay/channel/moonshot/adaptor.go @@ -44,19 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayFormat { - case relaycommon.RelayFormatClaude: - return fmt.Sprintf("%s/anthropic/v1/messages", info.BaseUrl), nil + case types.RelayFormatClaude: + return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil default: if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeChatCompletions { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeCompletions { - return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } } @@ -89,10 +89,10 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: adaptor := openai.Adaptor{} return adaptor.DoResponse(c, resp, info) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: if info.IsStream { err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 1f3fda8d..1a0caf75 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -48,14 +48,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - if info.RelayFormat == relaycommon.RelayFormatClaude { - return info.BaseUrl + "/v1/chat/completions", nil + if info.RelayFormat == types.RelayFormatClaude { + return info.ChannelBaseUrl + "/v1/chat/completions", nil } switch info.RelayMode { case relayconstant.RelayModeEmbeddings: - return info.BaseUrl + "/api/embed", nil + return info.ChannelBaseUrl + "/api/embed", nil default: - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index d4686ce3..066581fa 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -94,7 +94,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) @@ -123,7 +123,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.IOCopyBytesGracefully(c, resp, doResponseBody) + service.IOCopyBytesGracefully(c, resp, doResponseBody) return usage, nil } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index fc1749a0..d783b9d8 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -105,14 +105,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == relayconstant.RelayModeRealtime { - if strings.HasPrefix(info.BaseUrl, "https://") { - baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") + if strings.HasPrefix(info.ChannelBaseUrl, "https://") { + baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://") baseUrl = "wss://" + baseUrl - info.BaseUrl = baseUrl - } else if strings.HasPrefix(info.BaseUrl, "http://") { - baseUrl := strings.TrimPrefix(info.BaseUrl, "http://") + info.ChannelBaseUrl = baseUrl + } else if strings.HasPrefix(info.ChannelBaseUrl, "http://") { + baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://") baseUrl = "ws://" + baseUrl - info.BaseUrl = baseUrl + info.ChannelBaseUrl = baseUrl } } switch info.ChannelType { @@ -126,7 +126,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) task := strings.TrimPrefix(requestURL, "/v1/") - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { task = strings.TrimPrefix(task, "messages") task = "chat/completions" + task } @@ -136,7 +136,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { responsesApiVersion := "preview" subUrl := "/openai/v1/responses" - if strings.Contains(info.BaseUrl, "cognitiveservices.azure.com") { + if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") { subUrl = "/openai/responses" responsesApiVersion = apiVersion } @@ -146,7 +146,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion) - return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil } model_ := info.UpstreamModelName @@ -159,18 +159,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == relayconstant.RelayModeRealtime { requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion) } - return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil case constant.ChannelTypeMiniMax: return minimax.GetRequestURL(info) case constant.ChannelTypeCustom: - url := info.BaseUrl + url := info.ChannelBaseUrl url = strings.Replace(url, "{model}", info.UpstreamModelName, -1) return url, nil default: - if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini { + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 696c5cb0..2a4b4938 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -7,10 +7,12 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -21,11 +23,11 @@ func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string info.SendResponseCount++ switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: return sendStreamData(c, info, data, forceFormat, thinkToContent) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: return handleClaudeFormat(c, data, info) - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: return handleGeminiFormat(c, data, info) } return nil @@ -50,7 +52,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { - common.LogError(c, "failed to unmarshal stream response: "+err.Error()) + logger.LogError(c, "failed to unmarshal stream response: "+err.Error()) return err } @@ -63,7 +65,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { - common.LogError(c, "failed to marshal gemini response: "+err.Error()) + logger.LogError(c, "failed to marshal gemini response: "+err.Error()) return err } @@ -110,14 +112,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex var streamResponses []dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { return err } if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { - common.SysError("error processing stream response: " + err.Error()) + common.SysLog("error processing stream response: " + err.Error()) } } return nil @@ -146,7 +148,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui var streamResponses []dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { @@ -201,7 +203,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream usage *dto.Usage, containStreamUsage bool) { switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: if info.ShouldIncludeUsage && !containStreamUsage { response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) @@ -209,11 +211,11 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream } helper.Done(c) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: info.ClaudeConvertInfo.Done = true var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return } @@ -224,10 +226,10 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream _ = helper.ClaudeData(c, *resp) } - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return } @@ -245,7 +247,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { - common.SysError("error marshalling gemini response: " + err.Error()) + common.SysLog("error marshalling gemini response: " + err.Error()) return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b8e72273..00dde46d 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -108,11 +109,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") + logger.LogError(c, "invalid response or response body") return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) } - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) model := info.UpstreamModelName var responseId string @@ -129,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re if lastStreamData != "" { err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) if err != nil { - common.SysError("error handling stream format: " + err.Error()) + common.SysLog("error handling stream format: " + err.Error()) } } if len(data) > 0 { @@ -143,10 +144,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re shouldSendLastResp := true if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage, &containStreamUsage, info, &shouldSendLastResp); err != nil { - common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData)) + logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData)) } - if info.RelayFormat == relaycommon.RelayFormatOpenAI { + if info.RelayFormat == types.RelayFormatOpenAI { if shouldSendLastResp { _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) } @@ -154,7 +155,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re // 处理token计算 if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { - common.LogError(c, "error processing tokens: "+err.Error()) + logger.LogError(c, "error processing tokens: "+err.Error()) } if !containStreamUsage { @@ -173,7 +174,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) @@ -210,7 +211,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: if forceFormat { responseBody, err = common.Marshal(simpleResponse) if err != nil { @@ -219,14 +220,14 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } else { break } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeRespStr, err := common.Marshal(claudeResp) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } responseBody = claudeRespStr - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info) geminiRespStr, err := common.Marshal(geminiResp) if err != nil { @@ -235,7 +236,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo responseBody = geminiRespStr } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &simpleResponse.Usage, nil } @@ -247,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel // if the upstream returns a specific status code, once the upstream has already written the header, // the subsequent failure of the response body should be regarded as a non-recoverable error, // and can be terminated directly. - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens usage.TotalTokens = info.PromptTokens @@ -258,13 +259,13 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel c.Writer.WriteHeaderNow() _, err := io.Copy(c.Writer, resp.Body) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } return usage } func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // count tokens by audio file duration audioTokens, err := countAudioTokens(c) @@ -276,7 +277,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) usage := &dto.Usage{} usage.PromptTokens = audioTokens @@ -386,7 +387,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.InputTokens += textToken + audioToken localUsage.InputTokenDetails.TextTokens += textToken @@ -459,7 +460,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken info.IsFirstRequest = false localUsage.InputTokens += textToken + audioToken @@ -474,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. localUsage = &dto.RealtimeUsage{} // print now usage } - common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session @@ -491,7 +492,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.OutputTokens += textToken + audioToken localUsage.OutputTokenDetails.TextTokens += textToken @@ -517,7 +518,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. case <-targetClosed: case err := <-errChan: //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil - common.LogError(c, "realtime error: "+err.Error()) + logger.LogError(c, "realtime error: "+err.Error()) case <-c.Done(): } @@ -553,7 +554,7 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R } func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -567,7 +568,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) // Once we've written to the client, we should not return errors anymore // because the upstream has already consumed resources and returned content diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index bae6fcb6..754a6f44 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -16,7 +17,7 @@ import ( ) func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // read response body var responsesResponse dto.OpenAIResponsesResponse @@ -33,7 +34,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) // compute usage usage := dto.Usage{} @@ -54,7 +55,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") + logger.LogError(c, "invalid response or response body") return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 4d1ab783..2a022a1b 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil + return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 9b8bce7d..3a6ec2f4 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -58,15 +58,15 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, go func() { responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError("error reading stream response: " + err.Error()) + common.SysLog("error reading stream response: " + err.Error()) stopChan <- true return } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) stopChan <- true return } @@ -78,7 +78,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, } jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) stopChan <- true return } @@ -96,7 +96,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, responseText } @@ -105,7 +105,7 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { @@ -133,6 +133,6 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &usage, nil } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 92cb08a2..8ab9c854 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 05e6d453..4c176c08 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeChatCompletions { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeCompletions { - return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go index 2e37ad15..b21faccb 100644 --- a/relay/channel/siliconflow/relay-siliconflow.go +++ b/relay/channel/siliconflow/relay-siliconflow.go @@ -4,9 +4,9 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -17,7 +17,7 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var siliconflowResp SFRerankResponse err = json.Unmarshal(responseBody, &siliconflowResp) if err != nil { @@ -39,6 +39,6 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 8d057513..a5ada137 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -76,7 +76,7 @@ type TaskAdaptor struct { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl // apiKey format: "access_key|secret_key" keyParts := strings.Split(info.ApiKey, "|") diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index b7b9a5ff..1fecda08 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -81,7 +81,7 @@ type TaskAdaptor struct { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl a.apiKey = info.ApiKey // apiKey format: "access_key|secret_key" diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 9c04c7ad..df2bb99e 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -59,7 +59,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { - baseURL := info.BaseUrl + baseURL := info.ChannelBaseUrl fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action) return fullRequestURL, nil } @@ -139,7 +139,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody)) if err != nil { - common.SysError(fmt.Sprintf("Get Task error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task error: %v", err)) return nil, err } defer req.Body.Close() diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index f40b480c..b0cc0bdc 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -86,7 +86,7 @@ type TaskAdaptor struct { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError { diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index b86d8a16..ab96ecaa 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/", info.BaseUrl), nil + return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index 78ce6238..f33a275c 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -106,7 +106,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt var tencentResponse TencentChatResponse err := json.Unmarshal([]byte(data), &tencentResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) continue } @@ -117,17 +117,17 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt err = helper.ObjectData(c, response) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) } } if err := scanner.Err(); err != nil { - common.SysError("error reading stream: " + err.Error()) + common.SysLog("error reading stream: " + err.Error()) } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil } @@ -138,7 +138,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &tencentSb) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) @@ -156,7 +156,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &fullTextResponse.Usage, nil } diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 2cc4f663..b46cb952 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -188,17 +188,17 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: if strings.HasPrefix(info.UpstreamModelName, "bot") { - return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil case constant.RelayModeRerank: - return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 6a3a5370..d5671ab2 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf xaiRequest := ImageRequest{ Model: request.Model, Prompt: request.Prompt, - N: request.N, + N: int(request.N), ResponseFormat: request.ResponseFormat, } return xaiRequest, nil @@ -49,7 +49,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 4d098102..5cae9c0a 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -47,7 +47,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var xAIResp *dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &xAIResp) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } @@ -63,7 +63,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount) err = helper.ObjectData(c, openaiResponse) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) } return true }) @@ -74,12 +74,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return usage, nil } func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -101,7 +101,7 @@ func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.IOCopyBytesGracefully(c, resp, encodeJson) + service.IOCopyBytesGracefully(c, resp, encodeJson) return xaiResponse.Usage, nil } diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 1a426d50..9d5c190f 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -143,7 +143,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap if err != nil || resp.StatusCode != 101 { return nil, nil, err } + + defer func() { + conn.Close() + }() + data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { @@ -218,20 +223,19 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap for { _, msg, err := conn.ReadMessage() if err != nil { - common.SysError("error reading stream response: " + err.Error()) + common.SysLog("error reading stream response: " + err.Error()) break } var response XunfeiChatResponse err = json.Unmarshal(msg, &response) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { - err := conn.Close() if err != nil { - common.SysError("error closing websocket connection: " + err.Error()) + common.SysLog("error closing websocket connection: " + err.Error()) } break } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index e3be0e8e..bd27c90b 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -45,7 +45,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.IsStream { method = "sse-invoke" } - return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil + return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 35882ed5..8eb0dcc1 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -10,6 +10,7 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" + "one-api/service" "one-api/types" "strings" "sync" @@ -38,7 +39,7 @@ func getZhipuToken(apikey string) string { split := strings.Split(apikey, ".") if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) + common.SysLog("invalid zhipu key: " + apikey) return "" } @@ -186,7 +187,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseZhipu2OpenAI(data) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -195,13 +196,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var zhipuResponse ZhipuStreamMetaResponse err := json.Unmarshal([]byte(data), &zhipuResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } usage = zhipuUsage @@ -212,7 +213,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return usage, nil } @@ -222,7 +223,7 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index a83e30e6..0fae3767 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -43,7 +43,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl) + baseUrl := fmt.Sprintf("%s/api/paas/v4", info.ChannelBaseUrl) switch info.RelayMode { case relayconstant.RelayModeEmbeddings: return fmt.Sprintf("%s/embeddings", baseUrl), nil diff --git a/relay/claude_handler.go b/relay/claude_handler.go index b4bf78ff..8f846f1c 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "errors" "fmt" "io" "net/http" @@ -18,68 +17,26 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { - textRequest = &dto.ClaudeRequest{} - err = c.ShouldBindJSON(textRequest) - if err != nil { - return nil, err - } - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return nil, errors.New("field messages is required") - } - if textRequest.Model == "" { - return nil, errors.New("field model is required") - } - return textRequest, nil -} +func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { -func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - relayInfo := relaycommon.GenRelayInfoClaude(c) + textRequest, ok := info.Request.(*dto.ClaudeRequest) - // get & validate textRequest 获取并验证文本请求 - textRequest, err := getAndValidateClaudeRequest(c) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request)) } - if textRequest.Stream { - relayInfo.IsStream = true - } - - err = helper.ModelMappedHelper(c, relayInfo, textRequest) + err := helper.ModelMappedHelper(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) - // count messages token error 计算promptTokens错误 - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) if textRequest.MaxTokens == 0 { textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model)) @@ -104,18 +61,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { textRequest.Temperature = common.GetPointer[float64](1.0) } textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") - relayInfo.UpstreamModelName = textRequest.Model + info.UpstreamModelName = textRequest.Model } var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) + convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -125,10 +82,10 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -145,14 +102,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { statusCodeMappingStr := c.GetString("status_code_mapping") var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -161,24 +118,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) //log.Printf("usage: %v", usage) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + + service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage)) return nil } - -func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - default: - promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName) - } - info.PromptTokens = promptTokens - return promptTokens, err -} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 5cd9223b..51142ff8 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -1,10 +1,13 @@ package common import ( + "errors" + "fmt" "one-api/common" "one-api/constant" "one-api/dto" relayconstant "one-api/relay/constant" + "one-api/types" "strings" "time" @@ -33,17 +36,6 @@ type ClaudeConvertInfo struct { Done bool } -const ( - RelayFormatOpenAI = "openai" - RelayFormatClaude = "claude" - RelayFormatGemini = "gemini" - RelayFormatOpenAIResponses = "openai_responses" - RelayFormatOpenAIAudio = "openai_audio" - RelayFormatOpenAIImage = "openai_image" - RelayFormatRerank = "rerank" - RelayFormatEmbedding = "embedding" -) - type RerankerInfo struct { Documents []any ReturnDocuments bool @@ -59,61 +51,182 @@ type ResponsesUsageInfo struct { BuiltInTools map[string]*BuildInToolInfo } -type RelayInfo struct { +type ChannelMeta struct { ChannelType int ChannelId int - ChannelIsMultiKey bool // 是否多密钥 - ChannelMultiKeyIndex int // 多密钥索引 - TokenId int - TokenKey string - UserId int - UsingGroup string // 使用的分组 - UserGroup string // 用户所在分组 - TokenUnlimited bool - StartTime time.Time - FirstResponseTime time.Time - isFirstResponse bool + ChannelIsMultiKey bool + ChannelMultiKeyIndex int + ChannelBaseUrl string + ApiType int + ApiVersion string + ApiKey string + Organization string + ChannelCreateTime int64 + ParamOverride map[string]interface{} + ChannelSetting dto.ChannelSettings + ChannelOtherSettings dto.ChannelOtherSettings + UpstreamModelName string + IsModelMapped bool + SupportStreamOptions bool // 是否支持流式选项 +} + +type RelayInfo struct { + TokenId int + TokenKey string + UserId int + UsingGroup string // 使用的分组 + UserGroup string // 用户所在分组 + TokenUnlimited bool + StartTime time.Time + FirstResponseTime time.Time + isFirstResponse bool //SendLastReasoningResponse bool - ApiType int IsStream bool IsGeminiBatchEmbedding bool IsPlayground bool UsePrice bool RelayMode int - UpstreamModelName string OriginModelName string - //RecodeModelName string - RequestURLPath string - ApiVersion string - PromptTokens int - ApiKey string - Organization string - BaseUrl string - SupportStreamOptions bool - ShouldIncludeUsage bool - DisablePing bool // 是否禁止向下游发送自定义 Ping - IsModelMapped bool - ClientWs *websocket.Conn - TargetWs *websocket.Conn - InputAudioFormat string - OutputAudioFormat string - RealtimeTools []dto.RealTimeTool - IsFirstRequest bool - AudioUsage bool - ReasoningEffort string - ChannelSetting dto.ChannelSettings - ChannelOtherSettings dto.ChannelOtherSettings - ParamOverride map[string]interface{} - UserSetting dto.UserSetting - UserEmail string - UserQuota int - RelayFormat string - SendResponseCount int - ChannelCreateTime int64 + RequestURLPath string + PromptTokens int + ShouldIncludeUsage bool + DisablePing bool // 是否禁止向下游发送自定义 Ping + ClientWs *websocket.Conn + TargetWs *websocket.Conn + InputAudioFormat string + OutputAudioFormat string + RealtimeTools []dto.RealTimeTool + IsFirstRequest bool + AudioUsage bool + ReasoningEffort string + UserSetting dto.UserSetting + UserEmail string + UserQuota int + RelayFormat types.RelayFormat + SendResponseCount int + FinalPreConsumedQuota int // 最终预消耗的配额 + + PriceData types.PriceData + + Request dto.Request + ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo *ResponsesUsageInfo + *ChannelMeta +} + +func (info *RelayInfo) InitChannelMeta(c *gin.Context) { + channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) + paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) + apiType, _ := common.ChannelType2APIType(channelType) + channelMeta := &ChannelMeta{ + ChannelType: channelType, + ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId), + ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), + ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex), + ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl), + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), + Organization: c.GetString("channel_organization"), + ChannelCreateTime: c.GetInt64("channel_create_time"), + ParamOverride: paramOverride, + UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), + IsModelMapped: false, + SupportStreamOptions: false, + } + + channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) + if ok { + channelMeta.ChannelSetting = channelSetting + } + + channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting) + if ok { + channelMeta.ChannelOtherSettings = channelOtherSettings + } + + if streamSupportedChannels[channelMeta.ChannelType] { + channelMeta.SupportStreamOptions = true + } + info.ChannelMeta = channelMeta +} + +func (info *RelayInfo) ToString() string { + if info == nil { + return "RelayInfo" + } + + // Basic info + b := &strings.Builder{} + fmt.Fprintf(b, "RelayInfo{ ") + fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat) + fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode) + fmt.Fprintf(b, "IsStream: %t, ", info.IsStream) + fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground) + fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath) + fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName) + fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens) + fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage) + fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing) + fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount) + fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota) + + // User & token info (mask secrets) + fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ", + info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota) + fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited) + + // Time info + latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds() + fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ", + info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs) + + // Audio / realtime + if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage { + fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ", + info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools)) + } + + // Reasoning + if info.ReasoningEffort != "" { + fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort) + } + + // Price data (non-sensitive) + if info.PriceData.UsePrice { + fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting()) + } + + // Channel metadata (mask ApiKey) + if info.ChannelMeta != nil { + cm := info.ChannelMeta + fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ", + cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions) + } + + // Responses usage info (non-sensitive) + if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 { + fmt.Fprintf(b, "ResponsesTools{ ") + first := true + for name, tool := range info.ResponsesUsageInfo.BuiltInTools { + if !first { + fmt.Fprintf(b, ", ") + } + first = false + if tool != nil { + fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount) + } else { + fmt.Fprintf(b, "%s: calls=0", name) + } + } + fmt.Fprintf(b, " }, ") + } + + fmt.Fprintf(b, "}") + return b.String() } // 定义支持流式选项的通道类型 @@ -132,7 +245,8 @@ var streamSupportedChannels = map[int]bool{ } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { - info := GenRelayInfo(c) + info := genBaseRelayInfo(c, nil) + info.RelayFormat = types.RelayFormatOpenAIRealtime info.ClientWs = ws info.InputAudioFormat = "pcm16" info.OutputAudioFormat = "pcm16" @@ -140,9 +254,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { return info } -func GenRelayInfoClaude(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatClaude +func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatClaude info.ShouldIncludeUsage = false info.ClaudeConvertInfo = &ClaudeConvertInfo{ LastMessagesType: LastMessageTypeNone, @@ -150,41 +264,41 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { return info } -func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { - info := GenRelayInfo(c) +func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo { + info := genBaseRelayInfo(c, request) info.RelayMode = relayconstant.RelayModeRerank - info.RelayFormat = RelayFormatRerank + info.RelayFormat = types.RelayFormatRerank info.RerankerInfo = &RerankerInfo{ - Documents: req.Documents, - ReturnDocuments: req.GetReturnDocuments(), + Documents: request.Documents, + ReturnDocuments: request.GetReturnDocuments(), } return info } -func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatOpenAIAudio +func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAIAudio return info } -func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatEmbedding +func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatEmbedding return info } -func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { - info := GenRelayInfo(c) +func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo { + info := genBaseRelayInfo(c, request) info.RelayMode = relayconstant.RelayModeResponses - info.RelayFormat = RelayFormatOpenAIResponses + info.RelayFormat = types.RelayFormatOpenAIResponses info.SupportStreamOptions = false info.ResponsesUsageInfo = &ResponsesUsageInfo{ BuiltInTools: make(map[string]*BuildInToolInfo), } - if len(req.Tools) > 0 { - for _, tool := range req.Tools { + if len(request.Tools) > 0 { + for _, tool := range request.Tools { toolType := common.Interface2String(tool["type"]) info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{ ToolName: toolType, @@ -200,104 +314,82 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel } } } - info.IsStream = req.Stream return info } -func GenRelayInfoGemini(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatGemini +func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatGemini info.ShouldIncludeUsage = false + return info } -func GenRelayInfoImage(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatOpenAIImage +func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAIImage return info } -func GenRelayInfo(c *gin.Context) *RelayInfo { - channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) - channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) - paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) +func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAI + return info +} + +func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { + + //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) + //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) + //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) - tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) - tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey) - userId := common.GetContextKeyInt(c, constant.ContextKeyUserId) - tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited) startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime) if startTime.IsZero() { startTime = time.Now() } + + isStream := false + + if request != nil { + isStream = request.IsStream(c) + } + // firstResponseTime = time.Now() - 1 second - apiType, _ := common.ChannelType2APIType(channelType) - info := &RelayInfo{ - UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), - UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), - isFirstResponse: true, - RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), - BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl), - RequestURLPath: c.Request.URL.String(), - ChannelType: channelType, - ChannelId: channelId, - TokenId: tokenId, - TokenKey: tokenKey, - UserId: userId, - UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), - UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), - TokenUnlimited: tokenUnlimited, + Request: request, + + UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId), + UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), + UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), + UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), + UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), + + OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), + PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens), + + TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId), + TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey), + TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited), + + isFirstResponse: true, + RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), + RequestURLPath: c.Request.URL.String(), + IsStream: isStream, + StartTime: startTime, FirstResponseTime: startTime.Add(-time.Second), - OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), - UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), - //RecodeModelName: c.GetString("original_model"), - IsModelMapped: false, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), - Organization: c.GetString("channel_organization"), - - ChannelCreateTime: c.GetInt64("channel_create_time"), - ParamOverride: paramOverride, - RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, }, - - ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), - ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex), } + if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg") info.RequestURLPath = "/v1" + info.RequestURLPath } - if info.BaseUrl == "" { - info.BaseUrl = constant.ChannelBaseURLs[channelType] - } - if info.ChannelType == constant.ChannelTypeAzure { - info.ApiVersion = GetAPIVersion(c) - } - if info.ChannelType == constant.ChannelTypeVertexAi { - info.ApiVersion = c.GetString("region") - } - if streamSupportedChannels[info.ChannelType] { - info.SupportStreamOptions = true - } - - channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) - if ok { - info.ChannelSetting = channelSetting - } - - channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting) - if ok { - info.ChannelOtherSettings = channelOtherSettings - } userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting) if ok { @@ -307,12 +399,43 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { return info } -func (info *RelayInfo) SetPromptTokens(promptTokens int) { - info.PromptTokens = promptTokens +func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) { + switch relayFormat { + case types.RelayFormatOpenAI: + return GenRelayInfoOpenAI(c, request), nil + case types.RelayFormatOpenAIAudio: + return GenRelayInfoOpenAIAudio(c, request), nil + case types.RelayFormatOpenAIImage: + return GenRelayInfoImage(c, request), nil + case types.RelayFormatOpenAIRealtime: + return GenRelayInfoWs(c, ws), nil + case types.RelayFormatClaude: + return GenRelayInfoClaude(c, request), nil + case types.RelayFormatRerank: + if request, ok := request.(*dto.RerankRequest); ok { + return GenRelayInfoRerank(c, request), nil + } + return nil, errors.New("request is not a RerankRequest") + case types.RelayFormatGemini: + return GenRelayInfoGemini(c, request), nil + case types.RelayFormatEmbedding: + return GenRelayInfoEmbedding(c, request), nil + case types.RelayFormatOpenAIResponses: + if request, ok := request.(*dto.OpenAIResponsesRequest); ok { + return GenRelayInfoResponses(c, request), nil + } + return nil, errors.New("request is not a OpenAIResponsesRequest") + case types.RelayFormatTask: + return genBaseRelayInfo(c, nil), nil + case types.RelayFormatMjProxy: + return genBaseRelayInfo(c, nil), nil + default: + return nil, errors.New("invalid relay format") + } } -func (info *RelayInfo) SetIsStream(isStream bool) { - info.IsStream = isStream +func (info *RelayInfo) SetPromptTokens(promptTokens int) { + info.PromptTokens = promptTokens } func (info *RelayInfo) SetFirstResponseTime() { @@ -334,11 +457,15 @@ type TaskRelayInfo struct { ConsumeQuota bool } -func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { - info := &TaskRelayInfo{ - RelayInfo: GenRelayInfo(c), +func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) { + relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil) + if err != nil { + return nil, err } - return info + info := &TaskRelayInfo{ + RelayInfo: relayInfo, + } + return info, nil } type TaskSubmitReq struct { diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index 57df5fe3..05dbfa6d 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -8,6 +8,7 @@ import ( "one-api/dto" "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -18,7 +19,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println("reranker response body: ", string(responseBody)) } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index fef8d2c9..99f0d817 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -8,7 +8,6 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "one-api/types" @@ -16,69 +15,27 @@ import ( "github.com/gin-gonic/gin" ) -func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { - token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) - return token -} +func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { -func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error { - if embeddingRequest.Input == nil { - return fmt.Errorf("input is empty") - } - if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { - embeddingRequest.Model = "omni-moderation-latest" - } - if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { - embeddingRequest.Model = c.Param("model") - } - return nil -} + info.InitChannelMeta(c) -func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoEmbedding(c) - - var embeddingRequest *dto.EmbeddingRequest - err := common.UnmarshalBodyReusable(c, &embeddingRequest) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request)) } - err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) + err := helper.ModelMappedHelper(c, info, embeddingRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptToken := getEmbeddingPromptToken(*embeddingRequest) - relayInfo.PromptTokens = promptToken - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) - convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest) + convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -88,7 +45,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -104,12 +61,12 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index e0581156..d50fff42 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -2,17 +2,16 @@ package relay import ( "bytes" - "errors" "fmt" "io" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/gemini" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -20,64 +19,6 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { - request := &dto.GeminiChatRequest{} - err := common.UnmarshalBodyReusable(c, request) - if err != nil { - return nil, err - } - if len(request.Contents) == 0 { - return nil, errors.New("contents is required") - } - return request, nil -} - -// 流模式 -// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx -func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) { - if c.Query("alt") == "sse" { - relayInfo.IsStream = true - } - - // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") { - // relayInfo.IsStream = true - // } -} - -func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) { - var inputTexts []string - for _, content := range textRequest.Contents { - for _, part := range content.Parts { - if part.Text != "" { - inputTexts = append(inputTexts, part.Text) - } - } - } - if len(inputTexts) == 0 { - return nil, nil - } - - sensitiveWords, err := service.CheckSensitiveInput(inputTexts) - return sensitiveWords, err -} - -func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int { - // 计算输入 token 数量 - var inputTexts []string - for _, content := range req.Contents { - for _, part := range content.Parts { - if part.Text != "" { - inputTexts = append(inputTexts, part.Text) - } - } - } - - inputText := strings.Join(inputTexts, "\n") - inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName) - info.PromptTokens = inputTokens - return inputTokens -} - func isNoThinkingRequest(req *dto.GeminiChatRequest) bool { if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget @@ -109,97 +50,61 @@ func trimModelThinking(modelName string) string { return modelName } -func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - req, err := getAndValidateGeminiRequest(c) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } +func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - relayInfo := relaycommon.GenRelayInfoGemini(c) - - // 检查 Gemini 流式模式 - checkGeminiStreamMode(c, relayInfo) - - if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkGeminiInputSensitive(req) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) - } + request, ok := info.Request.(*dto.GeminiChatRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request)) } // model mapped 模型映射 - err = helper.ModelMappedHelper(c, relayInfo, req) + err := helper.ModelMappedHelper(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens := value.(int) - relayInfo.SetPromptTokens(promptTokens) - } else { - promptTokens := getGeminiInputTokens(req, relayInfo) - c.Set("prompt_tokens", promptTokens) - } - if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - if isNoThinkingRequest(req) { + if isNoThinkingRequest(request) { // check is thinking - if !strings.Contains(relayInfo.OriginModelName, "-nothinking") { + if !strings.Contains(info.OriginModelName, "-nothinking") { // try to get no thinking model price - noThinkingModelName := relayInfo.OriginModelName + "-nothinking" + noThinkingModelName := info.OriginModelName + "-nothinking" containPrice := helper.ContainPriceOrRatio(noThinkingModelName) if containPrice { - relayInfo.OriginModelName = noThinkingModelName - relayInfo.UpstreamModelName = noThinkingModelName + info.OriginModelName = noThinkingModelName + info.UpstreamModelName = noThinkingModelName } } } - if req.GenerationConfig.ThinkingConfig == nil { - gemini.ThinkingAdaptor(req, relayInfo) + if request.GenerationConfig.ThinkingConfig == nil { + gemini.ThinkingAdaptor(request, info) } } - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre consume quota - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) // Clean up empty system instruction - if req.SystemInstructions != nil { + if request.SystemInstructions != nil { hasContent := false - for _, part := range req.SystemInstructions.Parts { + for _, part := range request.SystemInstructions.Parts { if part.Text != "" { hasContent = true break } } if !hasContent { - req.SystemInstructions = nil + request.SystemInstructions = nil } } var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) @@ -207,7 +112,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { requestBody = bytes.NewReader(body) } else { // 使用 ConvertGeminiRequest 转换请求格式 - convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req) + convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -217,10 +122,10 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -229,15 +134,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - if common.DebugEnabled { - println("Gemini request body: %s", string(jsonData)) - } + logger.LogDebug(c, "Gemini request body: "+string(jsonData)) + requestBody = bytes.NewReader(jsonData) } - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - common.LogError(c, "Do gemini request failed: "+err.Error()) + logger.LogError(c, "Do gemini request failed: "+err.Error()) return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -246,7 +150,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -255,23 +159,22 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) if openaiErr != nil { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } -func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoGemini(c) +func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents") - relayInfo.IsGeminiBatchEmbedding = isBatch + info.IsGeminiBatchEmbedding = isBatch - var promptTokens int var req any var err error var inputTexts []string @@ -303,35 +206,17 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } } } - promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName) - relayInfo.SetPromptTokens(promptTokens) - c.Set("prompt_tokens", promptTokens) - err = helper.ModelMappedHelper(c, relayInfo, req) + err = helper.ModelMappedHelper(c, info, req) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader jsonData, err := common.Marshal(req) @@ -340,10 +225,10 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -353,9 +238,9 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewReader(jsonData) - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - common.LogError(c, "Do gemini request failed: "+err.Error()) + logger.LogError(c, "Do gemini request failed: "+err.Error()) return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -370,12 +255,12 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) if openaiErr != nil { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/helper/common.go b/relay/helper/common.go index c8edb798..4b2c51eb 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/types" "github.com/gin-gonic/gin" @@ -100,7 +101,7 @@ func Done(c *gin.Context) { func WssString(c *gin.Context, ws *websocket.Conn, str string) error { if ws == nil { - common.LogError(c, "websocket connection is nil") + logger.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } //common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) @@ -113,7 +114,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { return fmt.Errorf("error marshalling object: %w", err) } if ws == nil { - common.LogError(c, "websocket connection is nil") + logger.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } //common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) @@ -121,6 +122,9 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { } func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) { + if ws == nil { + return + } errorObj := &dto.RealtimeEvent{ Type: "error", EventId: GetLocalRealtimeID(c), diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go index c1735149..e894e228 100644 --- a/relay/helper/model_mapped.go +++ b/relay/helper/model_mapped.go @@ -4,9 +4,10 @@ import ( "encoding/json" "errors" "fmt" - common2 "one-api/common" "one-api/dto" + common2 "one-api/logger" "one-api/relay/common" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -54,29 +55,29 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro } if request != nil { switch info.RelayFormat { - case common.RelayFormatGemini: + case types.RelayFormatGemini: // Gemini 模型映射 - case common.RelayFormatClaude: + case types.RelayFormatClaude: if claudeRequest, ok := request.(*dto.ClaudeRequest); ok { claudeRequest.Model = info.UpstreamModelName } - case common.RelayFormatOpenAIResponses: + case types.RelayFormatOpenAIResponses: if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok { openAIResponsesRequest.Model = info.UpstreamModelName } - case common.RelayFormatOpenAIAudio: + case types.RelayFormatOpenAIAudio: if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok { openAIAudioRequest.Model = info.UpstreamModelName } - case common.RelayFormatOpenAIImage: + case types.RelayFormatOpenAIImage: if imageRequest, ok := request.(*dto.ImageRequest); ok { imageRequest.Model = info.UpstreamModelName } - case common.RelayFormatRerank: + case types.RelayFormatRerank: if rerankRequest, ok := request.(*dto.RerankRequest); ok { rerankRequest.Model = info.UpstreamModelName } - case common.RelayFormatEmbedding: + case types.RelayFormatEmbedding: if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok { embeddingRequest.Model = info.UpstreamModelName } diff --git a/relay/helper/price.go b/relay/helper/price.go index e80578e5..fdc5b66d 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -5,35 +5,14 @@ import ( "one-api/common" relaycommon "one-api/relay/common" "one-api/setting/ratio_setting" + "one-api/types" "github.com/gin-gonic/gin" ) -type GroupRatioInfo struct { - GroupRatio float64 - GroupSpecialRatio float64 - HasSpecialRatio bool -} - -type PriceData struct { - ModelPrice float64 - ModelRatio float64 - CompletionRatio float64 - CacheRatio float64 - CacheCreationRatio float64 - ImageRatio float64 - UsePrice bool - ShouldPreConsumedQuota int - GroupRatioInfo GroupRatioInfo -} - -func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) -} - // HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present -func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo { - groupRatioInfo := GroupRatioInfo{ +func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo { + groupRatioInfo := types.GroupRatioInfo{ GroupRatio: 1.0, // default ratio GroupSpecialRatio: -1, } @@ -62,7 +41,7 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR return groupRatioInfo } -func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { +func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) { modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false) groupRatioInfo := HandleGroupRatio(c, info) @@ -74,9 +53,9 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var imageRatio float64 var cacheCreationRatio float64 if !usePrice { - preConsumedTokens := common.PreConsumedQuota - if maxTokens != 0 { - preConsumedTokens = promptTokens + maxTokens + preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota) + if meta.MaxTokens != 0 { + preConsumedTokens += meta.MaxTokens } var success bool var matchName string @@ -87,7 +66,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens acceptUnsetRatio = true } if !acceptUnsetRatio { - return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) + return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) } } completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName) @@ -97,10 +76,13 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens ratio := modelRatio * groupRatioInfo.GroupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { + if meta.ImagePriceRatio != 0 { + modelPrice = modelPrice * meta.ImagePriceRatio + } preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) } - priceData := PriceData{ + priceData := types.PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, CompletionRatio: completionRatio, @@ -115,18 +97,12 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens if common.DebugEnabled { println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting())) } - + info.PriceData = priceData return priceData, nil } -type PerCallPriceData struct { - ModelPrice float64 - Quota int - GroupRatioInfo GroupRatioInfo -} - // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData { +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { groupRatioInfo := HandleGroupRatio(c, info) modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) @@ -140,7 +116,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCal } } quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) - priceData := PerCallPriceData{ + priceData := types.PerCallPriceData{ ModelPrice: modelPrice, Quota: quota, GroupRatioInfo: groupRatioInfo, diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index a5706f95..725d178c 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/constant" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/setting/operation_setting" "strings" @@ -87,7 +88,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case <-done: case <-time.After(5 * time.Second): - common.LogError(c, "timeout waiting for goroutines to exit") + logger.LogError(c, "timeout waiting for goroutines to exit") } close(stopChan) @@ -109,7 +110,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer func() { wg.Done() if r := recover(); r != nil { - common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) + logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) common.SafeSendBool(stopChan, true) } if common.DebugEnabled { @@ -136,14 +137,14 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case err := <-done: if err != nil { - common.LogError(c, "ping data error: "+err.Error()) + logger.LogError(c, "ping data error: "+err.Error()) return } if common.DebugEnabled { println("ping data sent") } case <-time.After(10 * time.Second): - common.LogError(c, "ping data send timeout") + logger.LogError(c, "ping data send timeout") return case <-ctx.Done(): return @@ -158,7 +159,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon // 监听客户端断开连接 return case <-pingTimeout.C: - common.LogError(c, "ping goroutine max duration reached") + logger.LogError(c, "ping goroutine max duration reached") return } } @@ -171,7 +172,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer func() { wg.Done() if r := recover(); r != nil { - common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) + logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) } common.SafeSendBool(stopChan, true) if common.DebugEnabled { @@ -223,7 +224,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon return } case <-time.After(10 * time.Second): - common.LogError(c, "data handler timeout") + logger.LogError(c, "data handler timeout") return case <-ctx.Done(): return @@ -241,7 +242,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon if err := scanner.Err(); err != nil { if err != io.EOF { - common.LogError(c, "scanner error: "+err.Error()) + logger.LogError(c, "scanner error: "+err.Error()) } } }) @@ -250,12 +251,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case <-ticker.C: // 超时处理逻辑 - common.LogError(c, "streaming timeout") + logger.LogError(c, "streaming timeout") case <-stopChan: // 正常结束 - common.LogInfo(c, "streaming finished") + logger.LogInfo(c, "streaming finished") case <-c.Request.Context().Done(): // 客户端断开连接 - common.LogInfo(c, "client disconnected") + logger.LogInfo(c, "client disconnected") } } diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go new file mode 100644 index 00000000..1d556a33 --- /dev/null +++ b/relay/helper/valid_request.go @@ -0,0 +1,301 @@ +package helper + +import ( + "errors" + "fmt" + "math" + "one-api/common" + "one-api/dto" + "one-api/logger" + relayconstant "one-api/relay/constant" + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + +func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) { + relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) + + switch format { + case types.RelayFormatOpenAI: + request, err = GetAndValidateTextRequest(c, relayMode) + case types.RelayFormatGemini: + request, err = GetAndValidateGeminiRequest(c) + case types.RelayFormatClaude: + request, err = GetAndValidateClaudeRequest(c) + case types.RelayFormatOpenAIResponses: + request, err = GetAndValidateResponsesRequest(c) + + case types.RelayFormatOpenAIImage: + request, err = GetAndValidOpenAIImageRequest(c, relayMode) + case types.RelayFormatEmbedding: + request, err = GetAndValidateEmbeddingRequest(c, relayMode) + case types.RelayFormatRerank: + request, err = GetAndValidateRerankRequest(c) + case types.RelayFormatOpenAIAudio: + request, err = GetAndValidAudioRequest(c, relayMode) + case types.RelayFormatOpenAIRealtime: + request = &dto.BaseRequest{} + default: + return nil, fmt.Errorf("unsupported relay format: %s", format) + } + return request, err +} + +func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) { + audioRequest := &dto.AudioRequest{} + err := common.UnmarshalBodyReusable(c, audioRequest) + if err != nil { + return nil, err + } + switch relayMode { + case relayconstant.RelayModeAudioSpeech: + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + default: + err = c.Request.ParseForm() + if err != nil { + return nil, err + } + formData := c.Request.PostForm + if audioRequest.Model == "" { + audioRequest.Model = formData.Get("model") + } + + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + audioRequest.ResponseFormat = formData.Get("response_format") + if audioRequest.ResponseFormat == "" { + audioRequest.ResponseFormat = "json" + } + } + return audioRequest, nil +} + +func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) { + var rerankRequest *dto.RerankRequest + err := common.UnmarshalBodyReusable(c, &rerankRequest) + if err != nil { + logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + if rerankRequest.Query == "" { + return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + if len(rerankRequest.Documents) == 0 { + return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + return rerankRequest, nil +} + +func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) { + var embeddingRequest *dto.EmbeddingRequest + err := common.UnmarshalBodyReusable(c, &embeddingRequest) + if err != nil { + logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + if embeddingRequest.Input == nil { + return nil, fmt.Errorf("input is empty") + } + if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { + embeddingRequest.Model = "omni-moderation-latest" + } + if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { + embeddingRequest.Model = c.Param("model") + } + return embeddingRequest, nil +} + +func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { + request := &dto.OpenAIResponsesRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if request.Model == "" { + return nil, errors.New("model is required") + } + if request.Input == nil { + return nil, errors.New("input is required") + } + return request, nil +} + +func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) { + imageRequest := &dto.ImageRequest{} + + switch relayMode { + case relayconstant.RelayModeImagesEdits: + _, err := c.MultipartForm() + if err != nil { + return nil, err + } + formData := c.Request.PostForm + imageRequest.Prompt = formData.Get("prompt") + imageRequest.Model = formData.Get("model") + imageRequest.N = uint(common.String2Int(formData.Get("n"))) + imageRequest.Quality = formData.Get("quality") + imageRequest.Size = formData.Get("size") + + if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + + watermark := formData.Has("watermark") + if watermark { + imageRequest.Watermark = &watermark + } + default: + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-3" + } + + if strings.Contains(imageRequest.Size, "×") { + return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") + } + + // Not "256x256", "512x512", or "1024x1024" + if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { + if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + } else if imageRequest.Model == "dall-e-3" { + if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { + return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") + } + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + } else if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "auto" + } + } + + if imageRequest.Prompt == "" { + return nil, errors.New("prompt is required") + } + + if imageRequest.N == 0 { + imageRequest.N = 1 + } + } + + return imageRequest, nil +} + +func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { + textRequest = &dto.ClaudeRequest{} + err = c.ShouldBindJSON(textRequest) + if err != nil { + return nil, err + } + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + if textRequest.Model == "" { + return nil, errors.New("field model is required") + } + + //if textRequest.Stream { + // relayInfo.IsStream = true + //} + + return textRequest, nil +} + +func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) { + textRequest := &dto.GeneralOpenAIRequest{} + err := common.UnmarshalBodyReusable(c, textRequest) + if err != nil { + return nil, err + } + + if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } + if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } + + if textRequest.MaxTokens > math.MaxInt32/2 { + return nil, errors.New("max_tokens is invalid") + } + if textRequest.Model == "" { + return nil, errors.New("model is required") + } + if textRequest.WebSearchOptions != nil { + if textRequest.WebSearchOptions.SearchContextSize != "" { + validSizes := map[string]bool{ + "high": true, + "medium": true, + "low": true, + } + if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { + return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") + } + } else { + textRequest.WebSearchOptions.SearchContextSize = "medium" + } + } + switch relayMode { + case relayconstant.RelayModeCompletions: + if textRequest.Prompt == "" { + return nil, errors.New("field prompt is required") + } + case relayconstant.RelayModeChatCompletions: + if len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + case relayconstant.RelayModeEmbeddings: + case relayconstant.RelayModeModerations: + if textRequest.Input == nil || textRequest.Input == "" { + return nil, errors.New("field input is required") + } + case relayconstant.RelayModeEdits: + if textRequest.Instruction == "" { + return nil, errors.New("field instruction is required") + } + } + return textRequest, nil +} + +func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { + + request := &dto.GeminiChatRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if len(request.Contents) == 0 { + return nil, errors.New("contents is required") + } + + //if c.Query("alt") == "sse" { + // relayInfo.IsStream = true + //} + + return request, nil +} diff --git a/relay/image_handler.go b/relay/image_handler.go index f0b69699..008a979d 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -3,19 +3,15 @@ package relay import ( "bytes" "encoding/json" - "errors" "fmt" "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" - "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -23,183 +19,41 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) { - imageRequest := &dto.ImageRequest{} +func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { - switch info.RelayMode { - case relayconstant.RelayModeImagesEdits: - _, err := c.MultipartForm() - if err != nil { - return nil, err - } - formData := c.Request.PostForm - imageRequest.Prompt = formData.Get("prompt") - imageRequest.Model = formData.Get("model") - imageRequest.N = common.String2Int(formData.Get("n")) - imageRequest.Quality = formData.Get("quality") - imageRequest.Size = formData.Get("size") + info.InitChannelMeta(c) - if imageRequest.Model == "gpt-image-1" { - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } + imageRequest, ok := info.Request.(*dto.ImageRequest) - if info.ApiType == constant.APITypeVolcEngine { - watermark := formData.Has("watermark") - imageRequest.Watermark = &watermark - } - default: - err := common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err - } - - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-3" - } - - if strings.Contains(imageRequest.Size, "×") { - return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") - } - - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - } else if imageRequest.Model == "dall-e-3" { - if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") - } - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - } else if imageRequest.Model == "gpt-image-1" { - if imageRequest.Quality == "" { - imageRequest.Quality = "auto" - } - } - - if imageRequest.Prompt == "" { - return nil, errors.New("prompt is required") - } - - if imageRequest.N == 0 { - imageRequest.N = 1 - } + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request)) } - if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveInput(imageRequest.Prompt) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) - return nil, err - } - } - return imageRequest, nil -} - -func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoImage(c) - - imageRequest, err := getAndValidImageRequest(c, relayInfo) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - err = helper.ModelMappedHelper(c, relayInfo, imageRequest) + err := helper.ModelMappedHelper(c, info, imageRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - var preConsumedQuota int - var quota int - var userQuota int - if !priceData.UsePrice { - // modelRatio 16 = modelPrice $0.04 - // per 1 modelRatio = $0.04 / 16 - // priceData.ModelPrice = 0.0025 * priceData.ModelRatio - preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - } else { - sizeRatio := 1.0 - qualityRatio := 1.0 - - if strings.HasPrefix(imageRequest.Model, "dall-e") { - // Size - if imageRequest.Size == "256x256" { - sizeRatio = 0.4 - } else if imageRequest.Size == "512x512" { - sizeRatio = 0.45 - } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1 - } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - sizeRatio = 2 - } - - if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" { - qualityRatio = 2.0 - if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - qualityRatio = 1.5 - } - } - } - - // reset model price - priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) - quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) - userQuota, err = model.GetUserQuota(relayInfo.UserId, false) - if err != nil { - return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) - } - if userQuota-quota < 0 { - return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry()) - } - } - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) + convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits { + if info.RelayMode == relayconstant.RelayModeImagesEdits { requestBody = convertedRequest.(io.Reader) } else { jsonData, err := json.Marshal(convertedRequest) @@ -208,10 +62,10 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -229,14 +83,14 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -245,7 +99,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) @@ -253,17 +107,23 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } if usage.(*dto.Usage).TotalTokens == 0 { - usage.(*dto.Usage).TotalTokens = imageRequest.N + usage.(*dto.Usage).TotalTokens = int(imageRequest.N) } if usage.(*dto.Usage).PromptTokens == 0 { - usage.(*dto.Usage).PromptTokens = imageRequest.N + usage.(*dto.Usage).PromptTokens = int(imageRequest.N) } + quality := "standard" if imageRequest.Quality == "hd" { quality = "hd" } - logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent) + var logContent string + + if len(imageRequest.Size) > 0 { + logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) + } + + postConsumeQuota(c, info, usage.(*dto.Usage), logContent) return nil } diff --git a/relay/relay-mj.go b/relay/mjproxy_handler.go similarity index 87% rename from relay/relay-mj.go rename to relay/mjproxy_handler.go index e7f316b9..756ad450 100644 --- a/relay/relay-mj.go +++ b/relay/mjproxy_handler.go @@ -170,13 +170,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo return } -func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { - startTime := time.Now().UnixNano() / int64(time.Millisecond) - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - //group := c.GetString("group") - channelId := c.GetInt("channel_id") - relayInfo := relaycommon.GenRelayInfo(c) +func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse { var swapFaceRequest dto.SwapFaceRequest err := common.UnmarshalBodyReusable(c, &swapFaceRequest) if err != nil { @@ -187,9 +181,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } modelName := service.CoverActionToModelName(constant.MjActionSwapFace) - priceData := helper.ModelPriceHelperPerCall(c, relayInfo) + priceData := helper.ModelPriceHelperPerCall(c, info) - userQuota, err := model.GetUserQuota(userId, false) + userQuota, err := model.GetUserQuota(info.UserId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -212,32 +206,31 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func() { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) + err := service.PostConsumeQuota(info, priceData.Quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) other := service.GenerateMjOtherInfo(priceData) - model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: channelId, + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: priceData.Quota, Content: logContent, - TokenId: tokenId, - UserQuota: userQuota, - Group: relayInfo.UsingGroup, + TokenId: info.TokenId, + Group: info.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) - model.UpdateChannelUsedQuota(channelId, priceData.Quota) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota) + model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota) } }() midjResponse := &mjResp.Response midjourneyTask := &model.Midjourney{ - UserId: userId, + UserId: info.UserId, Code: midjResponse.Code, Action: constant.MjActionSwapFace, MjId: midjResponse.Result, @@ -245,7 +238,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { PromptEn: "", Description: midjResponse.Description, State: "", - SubmitTime: startTime, + SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond), StartTime: time.Now().UnixNano() / int64(time.Millisecond), FinishTime: 0, ImageUrl: "", @@ -300,7 +293,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") } - common.IOCopyBytesGracefully(c, nil, respBody) + service.IOCopyBytesGracefully(c, nil, respBody) return nil } @@ -369,14 +362,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse return nil } -func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { - - //tokenId := c.GetInt("token_id") - //channelType := c.GetInt("channel") - userId := c.GetInt("id") - group := c.GetString("group") - channelId := c.GetInt("channel_id") - relayInfo := relaycommon.GenRelayInfo(c) +func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse { consumeQuota := true var midjRequest dto.MidjourneyRequest err := common.UnmarshalBodyReusable(c, &midjRequest) @@ -384,35 +370,35 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") } - if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 mjErr := service.CoverPlusActionToNormalAction(&midjRequest) if mjErr != nil { return mjErr } - relayMode = relayconstant.RelayModeMidjourneyChange + relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange } - if relayMode == relayconstant.RelayModeMidjourneyVideo { + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { midjRequest.Action = constant.MjActionVideo } - if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") } midjRequest.Action = constant.MjActionImagine - } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 midjRequest.Action = constant.MjActionDescribe - } else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复 midjRequest.Action = constant.MjActionEdits - } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only midjRequest.Action = constant.MjActionShorten - } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionBlend - } else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionUpload } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 mjId := "" - if relayMode == relayconstant.RelayModeMidjourneyChange { + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange { if midjRequest.TaskId == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") } else if midjRequest.Action == "" { @@ -422,7 +408,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } //action = midjRequest.Action mjId = midjRequest.TaskId - } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange { if midjRequest.Content == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") } @@ -432,13 +418,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } mjId = params.TaskId midjRequest.Action = params.Action - } else if relayMode == relayconstant.RelayModeMidjourneyModal { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal { //if midjRequest.MaskBase64 == "" { // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") //} mjId = midjRequest.TaskId midjRequest.Action = constant.MjActionModal - } else if relayMode == relayconstant.RelayModeMidjourneyVideo { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { midjRequest.Action = constant.MjActionVideo if midjRequest.TaskId == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") @@ -448,12 +434,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons mjId = midjRequest.TaskId } - originTask := model.GetByMJId(userId, mjId) + originTask := model.GetByMJId(relayInfo.UserId, mjId) if originTask == nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 if setting.MjActionCheckSuccessEnabled { - if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { + if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") } } @@ -496,7 +482,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons priceData := helper.ModelPriceHelperPerCall(c, relayInfo) - userQuota, err := model.GetUserQuota(userId, false) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -521,24 +507,23 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) other := service.GenerateMjOtherInfo(priceData) model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: channelId, + ChannelId: relayInfo.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: priceData.Quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, - Group: group, + Group: relayInfo.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) - model.UpdateChannelUsedQuota(channelId, priceData.Quota) + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota) } }() @@ -550,7 +535,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons // 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}} // other: 提交错误,description为错误描述 midjourneyTask := &model.Midjourney{ - UserId: userId, + UserId: relayInfo.UserId, Code: midjResponse.Code, Action: midjRequest.Action, MjId: midjResponse.Result, @@ -572,7 +557,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //无实例账号自动禁用渠道(No available account instance) channel, err := model.GetChannelById(midjourneyTask.ChannelId, true) if err != nil { - common.SysError("get_channel_null: " + err.Error()) + common.SysLog("get_channel_null: " + err.Error()) } if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") diff --git a/relay/relay-text.go b/relay/relay-text.go index 50d574f3..1e5aafd6 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -2,144 +2,48 @@ package relay import ( "bytes" - "errors" "fmt" "io" - "math" "net/http" "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/setting/operation_setting" "one-api/types" "strings" "time" - "github.com/bytedance/gopkg/util/gopool" "github.com/shopspring/decimal" "github.com/gin-gonic/gin" ) -func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) { - textRequest := &dto.GeneralOpenAIRequest{} - err := common.UnmarshalBodyReusable(c, textRequest) - if err != nil { - return nil, err - } - if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { - textRequest.Model = "text-moderation-latest" - } - if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { - textRequest.Model = c.Param("model") - } +func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { - if textRequest.MaxTokens > math.MaxInt32/2 { - return nil, errors.New("max_tokens is invalid") - } - if textRequest.Model == "" { - return nil, errors.New("model is required") - } - if textRequest.WebSearchOptions != nil { - if textRequest.WebSearchOptions.SearchContextSize != "" { - validSizes := map[string]bool{ - "high": true, - "medium": true, - "low": true, - } - if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { - return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") - } - } else { - textRequest.WebSearchOptions.SearchContextSize = "medium" - } - } - switch relayInfo.RelayMode { - case relayconstant.RelayModeCompletions: - if textRequest.Prompt == "" { - return nil, errors.New("field prompt is required") - } - case relayconstant.RelayModeChatCompletions: - if len(textRequest.Messages) == 0 { - return nil, errors.New("field messages is required") - } - case relayconstant.RelayModeEmbeddings: - case relayconstant.RelayModeModerations: - if textRequest.Input == nil || textRequest.Input == "" { - return nil, errors.New("field input is required") - } - case relayconstant.RelayModeEdits: - if textRequest.Instruction == "" { - return nil, errors.New("field instruction is required") - } - } - relayInfo.IsStream = textRequest.Stream - return textRequest, nil -} + info.InitChannelMeta(c) -func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { + textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest) - relayInfo := relaycommon.GenRelayInfo(c) - - // get & validate textRequest 获取并验证文本请求 - textRequest, err := getAndValidateTextRequest(c, relayInfo) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + if !ok { + //return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request) } if textRequest.WebSearchOptions != nil { c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize) } - if setting.ShouldCheckPromptSensitive() { - words, err := checkRequestSensitive(textRequest, relayInfo) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) - } - } - - err = helper.ModelMappedHelper(c, relayInfo, textRequest) + err := helper.ModelMappedHelper(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - // 获取 promptTokens,如果上下文中已经存在,则直接使用 - var promptTokens int - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens = value.(int) - relayInfo.PromptTokens = promptTokens - } else { - promptTokens, err = getPromptTokens(textRequest, relayInfo) - // count messages token error 计算promptTokens错误 - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) - } - c.Set("prompt_tokens", promptTokens) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens)))) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newApiErr != nil { - return newApiErr - } - defer func() { - if newApiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() includeUsage := true // 判断用户是否需要返回使用情况 if textRequest.StreamOptions != nil { @@ -147,7 +51,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // 如果不支持StreamOptions,将StreamOptions设置为nil - if !relayInfo.SupportStreamOptions || !textRequest.Stream { + if !info.SupportStreamOptions || !textRequest.Stream { textRequest.StreamOptions = nil } else { // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions @@ -158,16 +62,16 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - relayInfo.ShouldIncludeUsage = includeUsage + info.ShouldIncludeUsage = includeUsage - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) @@ -177,12 +81,12 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) + convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - if relayInfo.ChannelSetting.SystemPrompt != "" { + if info.ChannelSetting.SystemPrompt != "" { // 如果有系统提示,则将其添加到请求中 request := convertedRequest.(*dto.GeneralOpenAIRequest) containSystemPrompt := false @@ -196,22 +100,22 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { // 如果没有系统提示,则添加系统提示 systemMessage := dto.Message{ Role: request.GetSystemRoleName(), - Content: relayInfo.ChannelSetting.SystemPrompt, + Content: info.ChannelSetting.SystemPrompt, } request.Messages = append([]dto.Message{systemMessage}, request.Messages...) - } else if relayInfo.ChannelSetting.SystemPromptOverride { + } else if info.ChannelSetting.SystemPromptOverride { common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) // 如果有系统提示,且允许覆盖,则拼接到前面 for i, message := range request.Messages { if message.Role == request.GetSystemRoleName() { if message.IsStringContent() { - request.Messages[i].SetStringContent(relayInfo.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) + request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) } else { contents := message.ParseContent() contents = append([]dto.MediaContent{ { Type: dto.ContentTypeText, - Text: relayInfo.ChannelSetting.SystemPrompt, + Text: info.ChannelSetting.SystemPrompt, }, }, contents...) request.Messages[i].Content = contents @@ -228,10 +132,10 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -240,14 +144,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - if common.DebugEnabled { - println("requestBody: ", string(jsonData)) - } + logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData))) + requestBody = bytes.NewBuffer(jsonData) } var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -256,125 +159,31 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newApiErr = service.RelayErrorHandler(httpResp, false) + newApiErr := service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } } - usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newApiErr := adaptor.DoResponse(c, httpResp, info) if newApiErr != nil { // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } - if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") } return nil } -func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - promptTokens, err = service.CountTokenChatRequest(info, *textRequest) - case relayconstant.RelayModeCompletions: - promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model) - case relayconstant.RelayModeModerations: - promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) - case relayconstant.RelayModeEmbeddings: - promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) - default: - err = errors.New("unknown relay mode") - promptTokens = 0 - } - info.PromptTokens = promptTokens - return promptTokens, err -} - -func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) { - var err error - var words []string - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - words, err = service.CheckSensitiveMessages(textRequest.Messages) - case relayconstant.RelayModeCompletions: - words, err = service.CheckSensitiveInput(textRequest.Prompt) - case relayconstant.RelayModeModerations: - words, err = service.CheckSensitiveInput(textRequest.Input) - case relayconstant.RelayModeEmbeddings: - words, err = service.CheckSensitiveInput(textRequest.Input) - } - return words, err -} - -// 预扣费并返回用户剩余配额 -func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) { - userQuota, err := model.GetUserQuota(relayInfo.UserId, false) - if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) - } - if userQuota <= 0 { - return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - if userQuota-preConsumedQuota < 0 { - return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - relayInfo.UserQuota = userQuota - if userQuota > 100*preConsumedQuota { - // 用户额度充足,判断令牌额度是否充足 - if !relayInfo.TokenUnlimited { - // 非无限令牌,判断令牌额度是否充足 - tokenQuota := c.GetInt("token_quota") - if tokenQuota > 100*preConsumedQuota { - // 令牌额度充足,信任令牌 - preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) - } - } else { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota))) - } - } - - if preConsumedQuota > 0 { - err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) - if err != nil { - return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) - if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) - } - } - return preConsumedQuota, userQuota, nil -} - -func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) { - if preConsumedQuota != 0 { - gopool.Go(func() { - relayInfoCopy := *relayInfo - - err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) - if err != nil { - common.SysError("error return pre-consumed quota: " + err.Error()) - } - }) - } -} - -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { if usage == nil { usage = &dto.Usage{ PromptTokens: relayInfo.PromptTokens, @@ -392,12 +201,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := priceData.CompletionRatio - cacheRatio := priceData.CacheRatio - imageRatio := priceData.ImageRatio - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice + completionRatio := relayInfo.PriceData.CompletionRatio + cacheRatio := relayInfo.PriceData.CacheRatio + imageRatio := relayInfo.PriceData.ImageRatio + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -470,7 +279,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, var audioInputQuota decimal.Decimal var audioInputPrice float64 - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { baseTokens := dPromptTokens // 减去 cached tokens var cachedTokensWithRatio decimal.Decimal @@ -518,11 +327,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, totalTokens := promptTokens + completionTokens var logContent string - if !priceData.UsePrice { - logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio) - } else { - logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) - } // record all the consume log even if quota is 0 if totalTokens == 0 { @@ -530,8 +334,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { if !ratio.IsZero() && quota == 0 { quota = 1 @@ -540,11 +344,28 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota + + //logger.LogInfo(ctx, fmt.Sprintf("request quota delta: %s", logger.FormatQuota(quotaDelta))) + + if quotaDelta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } else if quotaDelta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } + if quotaDelta != 0 { - err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } @@ -560,7 +381,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, if extraContent != "" { logContent += ", " + extraContent } - other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) if imageTokens != 0 { other["image"] = true other["image_ratio"] = imageRatio @@ -604,7 +425,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, diff --git a/relay/relay_task.go b/relay/relay_task.go index 0ccc3b33..95b8083b 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -27,7 +27,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if platform == "" { platform = GetTaskPlatform(c) } - relayInfo := relaycommon.GenTaskRelayInfo(c) + + relayInfo, err := relaycommon.GenTaskRelayInfo(c) + if err != nil { + return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError) + } adaptor := GetTaskAdaptor(platform) if adaptor == nil { @@ -97,7 +101,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { c.Set("channel_id", originTask.ChannelId) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - relayInfo.BaseUrl = channel.GetBaseURL() + relayInfo.ChannelBaseUrl = channel.GetBaseURL() relayInfo.ChannelId = originTask.ChannelId } } @@ -127,7 +131,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") @@ -149,7 +153,6 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, Group: relayInfo.UsingGroup, Other: other, }) diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 1e547e2a..85e4f174 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -25,62 +25,33 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int { return token } -func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) { +func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { - var rerankRequest *dto.RerankRequest - err := common.UnmarshalBodyReusable(c, &rerankRequest) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + rerankRequest, ok := info.Request.(*dto.RerankRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request)) } - relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest) - - if rerankRequest.Query == "" { - return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - if len(rerankRequest.Documents) == 0 { - return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) + err := helper.ModelMappedHelper(c, info, rerankRequest) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptToken := getRerankPromptToken(*rerankRequest) - relayInfo.PromptTokens = promptToken - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) + convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -90,10 +61,10 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = common.Marshal(reqMap) @@ -108,7 +79,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError requestBody = bytes.NewBuffer(jsonData) } - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -125,12 +96,12 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 65c240b2..cd80da33 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -3,7 +3,6 @@ package relay import ( "bytes" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -12,7 +11,6 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -20,82 +18,24 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { - request := &dto.OpenAIResponsesRequest{} - err := common.UnmarshalBodyReusable(c, request) - if err != nil { - return nil, err - } - if request.Model == "" { - return nil, errors.New("model is required") - } - if len(request.Input) == 0 { - return nil, errors.New("input is required") - } - return request, nil +func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) -} - -func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { - sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) - return sensitiveWords, err -} - -func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int { - inputTokens := service.CountTokenInput(req.Input, req.Model) - info.PromptTokens = inputTokens - return inputTokens -} - -func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - req, err := getAndValidateResponsesRequest(c) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + request, ok := info.Request.(*dto.OpenAIResponsesRequest) + if !ok { + common.FatalLog(fmt.Sprintf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request)) } - relayInfo := relaycommon.GenRelayInfoResponses(c, req) - - if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkInputSensitive(req, relayInfo) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) - } - } - - err = helper.ModelMappedHelper(c, relayInfo, req) + err := helper.ModelMappedHelper(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens := value.(int) - relayInfo.SetPromptTokens(promptTokens) - } else { - promptTokens := getInputTokens(req, relayInfo) - c.Set("prompt_tokens", promptTokens) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - // pre consume quota - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled { body, err := common.GetRequestBody(c) @@ -104,7 +44,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req) + convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -113,13 +53,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override - if len(relayInfo.ParamOverride) > 0 { + if len(info.ParamOverride) > 0 { reqMap := make(map[string]interface{}) err = json.Unmarshal(jsonData, &reqMap) if err != nil { return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } - for key, value := range relayInfo.ParamOverride { + for key, value := range info.ParamOverride { reqMap[key] = value } jsonData, err = json.Marshal(reqMap) @@ -135,7 +75,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -153,17 +93,17 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") } return nil } diff --git a/relay/websocket.go b/relay/websocket.go index 3715b237..2d313154 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -4,7 +4,6 @@ import ( "fmt" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/relay/helper" "one-api/service" "one-api/types" @@ -12,65 +11,35 @@ import ( "github.com/gorilla/websocket" ) -func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoWs(c, ws) +func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - // get & validate textRequest 获取并验证文本请求 - //realtimeEvent, err := getAndValidateWssRequest(c, ws) - //if err != nil { - // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error())) - // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) - //} - - err := helper.ModelMappedHelper(c, relayInfo, nil) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) //var requestBody io.Reader //firstWssRequest, _ := c.Get("first_wss_request") //requestBody = bytes.NewBuffer(firstWssRequest.([]byte)) statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, nil) + resp, err := adaptor.DoRequest(c, info, nil) if err != nil { return types.NewError(err, types.ErrorCodeDoRequestFailed) } if resp != nil { - relayInfo.TargetWs = resp.(*websocket.Conn) - defer relayInfo.TargetWs.Close() + info.TargetWs = resp.(*websocket.Conn) + defer info.TargetWs.Close() } - usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, nil, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, - userQuota, priceData, "") + service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "") return nil } diff --git a/router/main.go b/router/main.go index 0d2bfdce..23576427 100644 --- a/router/main.go +++ b/router/main.go @@ -3,11 +3,12 @@ package router import ( "embed" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "os" "strings" + + "github.com/gin-gonic/gin" ) func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { diff --git a/router/relay-router.go b/router/relay-router.go index cd656580..e0f05e97 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -1,11 +1,13 @@ package router import ( - "github.com/gin-gonic/gin" "one-api/constant" "one-api/controller" "one-api/middleware" "one-api/relay" + "one-api/types" + + "github.com/gin-gonic/gin" ) func SetRelayRouter(router *gin.Engine) { @@ -62,28 +64,83 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.Use(middleware.TokenAuth()) relayV1Router.Use(middleware.ModelRequestRateLimit()) { - // WebSocket 路由 + // WebSocket 路由(统一到 Relay) wsRouter := relayV1Router.Group("") wsRouter.Use(middleware.Distribute()) - wsRouter.GET("/realtime", controller.WssRelay) + wsRouter.GET("/realtime", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIRealtime) + }) } { //http router httpRouter := relayV1Router.Group("") httpRouter.Use(middleware.Distribute()) - httpRouter.POST("/messages", controller.RelayClaude) - httpRouter.POST("/completions", controller.Relay) - httpRouter.POST("/chat/completions", controller.Relay) - httpRouter.POST("/edits", controller.Relay) - httpRouter.POST("/images/generations", controller.Relay) - httpRouter.POST("/images/edits", controller.Relay) + + // claude related routes + httpRouter.POST("/messages", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatClaude) + }) + + // chat related routes + httpRouter.POST("/completions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + httpRouter.POST("/chat/completions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + + // response related routes + httpRouter.POST("/responses", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIResponses) + }) + + // image related routes + httpRouter.POST("/edits", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + httpRouter.POST("/images/generations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + httpRouter.POST("/images/edits", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + + // embedding related routes + httpRouter.POST("/embeddings", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatEmbedding) + }) + + // audio related routes + httpRouter.POST("/audio/transcriptions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + httpRouter.POST("/audio/translations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + httpRouter.POST("/audio/speech", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + + // rerank related routes + httpRouter.POST("/rerank", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatRerank) + }) + + // gemini relay routes + httpRouter.POST("/engines/:model/embeddings", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) + httpRouter.POST("/models/*path", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) + + // other relay routes + httpRouter.POST("/moderations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + + // not implemented httpRouter.POST("/images/variations", controller.RelayNotImplemented) - httpRouter.POST("/embeddings", controller.Relay) - httpRouter.POST("/engines/:model/embeddings", controller.Relay) - httpRouter.POST("/audio/transcriptions", controller.Relay) - httpRouter.POST("/audio/translations", controller.Relay) - httpRouter.POST("/audio/speech", controller.Relay) - httpRouter.POST("/responses", controller.Relay) httpRouter.GET("/files", controller.RelayNotImplemented) httpRouter.POST("/files", controller.RelayNotImplemented) httpRouter.DELETE("/files/:id", controller.RelayNotImplemented) @@ -95,9 +152,6 @@ func SetRelayRouter(router *gin.Engine) { httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) httpRouter.DELETE("/models/:model", controller.RelayNotImplemented) - httpRouter.POST("/moderations", controller.Relay) - httpRouter.POST("/rerank", controller.Relay) - httpRouter.POST("/models/*path", controller.Relay) } relayMjRouter := router.Group("/mj") @@ -121,7 +175,9 @@ func SetRelayRouter(router *gin.Engine) { relayGeminiRouter.Use(middleware.Distribute()) { // Gemini API 路径格式: /v1beta/models/{model_name}:{action} - relayGeminiRouter.POST("/models/*path", controller.Relay) + relayGeminiRouter.POST("/models/*path", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) } } diff --git a/service/error.go b/service/error.go index 9672402d..ef5cbbde 100644 --- a/service/error.go +++ b/service/error.go @@ -85,7 +85,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t if err != nil { return } - common.CloseResponseBodyGracefully(resp) + CloseResponseBodyGracefully(resp) var errResponse dto.GeneralErrorResponse err = common.Unmarshal(responseBody, &errResponse) diff --git a/common/http.go b/service/http.go similarity index 86% rename from common/http.go rename to service/http.go index d2e824ef..357a2e78 100644 --- a/common/http.go +++ b/service/http.go @@ -1,10 +1,12 @@ -package common +package service import ( "bytes" "fmt" "io" "net/http" + "one-api/common" + "one-api/logger" "github.com/gin-gonic/gin" ) @@ -15,7 +17,7 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) { } err := httpResponse.Body.Close() if err != nil { - SysError("failed to close response body: " + err.Error()) + common.SysError("failed to close response body: " + err.Error()) } } @@ -52,6 +54,6 @@ func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) { _, err := io.Copy(c.Writer, body) if err != nil { - LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) + logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) } } diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 0dae9a03..7a609c9f 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -5,7 +5,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/relay/helper" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -78,7 +78,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, return info } -func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} { +func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} { other := make(map[string]interface{}) other["model_price"] = priceData.ModelPrice other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio diff --git a/service/midjourney.go b/service/midjourney.go index 1fc19682..916d02d0 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -212,7 +212,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU defer cancel() resp, err := GetHttpClient().Do(req) if err != nil { - common.SysError("do request failed: " + err.Error()) + common.SysLog("do request failed: " + err.Error()) return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err } statusCode := resp.StatusCode @@ -233,7 +233,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err } - common.CloseResponseBodyGracefully(resp) + CloseResponseBodyGracefully(resp) respStr := string(responseBody) log.Printf("respStr: %s", respStr) if respStr == "" { diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go new file mode 100644 index 00000000..08e3f68f --- /dev/null +++ b/service/pre_consume_quota.go @@ -0,0 +1,79 @@ +package service + +import ( + "errors" + "fmt" + "net/http" + "one-api/common" + "one-api/logger" + "one-api/model" + relaycommon "one-api/relay/common" + "one-api/types" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" +) + +func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { + if preConsumedQuota != 0 { + logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota))) + gopool.Go(func() { + relayInfoCopy := *relayInfo + + err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) + if err != nil { + common.SysLog("error return pre-consumed quota: " + err.Error()) + } + }) + } +} + +// PreConsumeQuota checks if the user has enough quota to pre-consume. +// It returns the pre-consumed quota if successful, or an error if not. +func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) { + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) + if err != nil { + return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) + } + if userQuota <= 0 { + return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + if userQuota-preConsumedQuota < 0 { + return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + + trustQuota := common.GetTrustQuota() + + relayInfo.UserQuota = userQuota + if userQuota > trustQuota { + // 用户额度充足,判断令牌额度是否充足 + if !relayInfo.TokenUnlimited { + // 非无限令牌,判断令牌额度是否充足 + tokenQuota := c.GetInt("token_quota") + if tokenQuota > trustQuota { + // 令牌额度充足,信任令牌 + preConsumedQuota = 0 + logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) + } + } else { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId)) + } + } + + if preConsumedQuota > 0 { + err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) + if err != nil { + return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) + if err != nil { + return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + } + logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota))) + } + relayInfo.FinalPreConsumedQuota = preConsumedQuota + return preConsumedQuota, nil +} diff --git a/service/quota.go b/service/quota.go index d17a077c..8f65bd20 100644 --- a/service/quota.go +++ b/service/quota.go @@ -8,11 +8,12 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - "one-api/relay/helper" "one-api/setting" "one-api/setting/ratio_setting" + "one-api/types" "strings" "time" @@ -137,23 +138,23 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag quota := calculateAudioQuota(quotaInfo) if userQuota < quota { - return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)) + return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota)) } if !token.UnlimitedQuota && token.RemainQuota < quota { - return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota)) } err = PostConsumeQuota(relayInfo, quota, 0, false) if err != nil { return err } - common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) + logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) return nil } func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { + usage *dto.RealtimeUsage, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.InputTokenDetails.TextTokens @@ -167,10 +168,10 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName)) - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - usePrice := priceData.UsePrice + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + usePrice := relayInfo.PriceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -204,8 +205,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) @@ -216,7 +217,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod logContent += ", " + extraContent } other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: usage.InputTokens, @@ -226,7 +227,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -234,8 +234,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod }) } -func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -243,21 +242,21 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := priceData.CompletionRatio - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - cacheRatio := priceData.CacheRatio + completionRatio := relayInfo.PriceData.CompletionRatio + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + cacheRatio := relayInfo.PriceData.CacheRatio cacheTokens := usage.PromptTokensDetails.CachedTokens - cacheCreationRatio := priceData.CacheCreationRatio + cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens if relayInfo.ChannelType == constant.ChannelTypeOpenRouter { promptTokens -= cacheTokens - isUsingCustomSettings := priceData.UsePrice || hasCustomModelRatio(modelName, priceData.ModelRatio) - if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings { - maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData) + isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio) + if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings { + maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData) if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens { cacheCreationTokens = maybeCacheCreationTokens } @@ -266,7 +265,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } calculateQuota := 0.0 - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { calculateQuota = float64(promptTokens) calculateQuota += float64(cacheTokens) * cacheRatio calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio @@ -291,23 +290,38 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游出错)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota + + if quotaDelta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } else if quotaDelta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } + if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, - cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, @@ -317,7 +331,6 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -326,7 +339,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } -func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int { +func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int { if priceData.CacheCreationRatio == 1 { return 0 } @@ -347,8 +360,7 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData (promptCacheCreatePrice - quotaPrice))) } -func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.PromptTokensDetails.TextTokens @@ -362,10 +374,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)) - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - usePrice := priceData.UsePrice + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + usePrice := relayInfo.PriceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -399,18 +411,33 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota + + if quotaDelta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } else if quotaDelta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } + if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } @@ -419,7 +446,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, logContent += ", " + extraContent } other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: usage.PromptTokens, @@ -429,7 +456,6 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -452,7 +478,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { return err } if !relayInfo.TokenUnlimited && token.RemainQuota < quota { - return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota)) } err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) if err != nil { @@ -510,7 +536,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon prompt := "您的额度即将用尽" topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" - err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink})) + err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink})) if err != nil { common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error())) } diff --git a/service/sensitive.go b/service/sensitive.go index b3e3c4d6..25cfd46f 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -2,7 +2,6 @@ package service import ( "errors" - "fmt" "one-api/dto" "one-api/setting" "strings" @@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) { return nil, nil } -func CheckSensitiveText(text string) ([]string, error) { - if ok, words := SensitiveWordContains(text); ok { - return words, errors.New("sensitive words detected") - } - return nil, nil -} - -func CheckSensitiveInput(input any) ([]string, error) { - switch v := input.(type) { - case string: - return CheckSensitiveText(v) - case []string: - var builder strings.Builder - for _, s := range v { - builder.WriteString(s) - } - return CheckSensitiveText(builder.String()) - } - return CheckSensitiveText(fmt.Sprintf("%v", input)) +func CheckSensitiveText(text string) (bool, []string) { + return SensitiveWordContains(text) } // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 diff --git a/service/token_counter.go b/service/token_counter.go index eed5b5ca..314fa593 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -4,8 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tiktoken-go/tokenizer" - "github.com/tiktoken-go/tokenizer/codec" "image" "log" "math" @@ -13,9 +11,14 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/types" "strings" "sync" "unicode/utf8" + + "github.com/gin-gonic/gin" + "github.com/tiktoken-go/tokenizer" + "github.com/tiktoken-go/tokenizer/codec" ) // tokenEncoderMap won't grow after initialization @@ -72,52 +75,95 @@ func getTokenNum(tokenEncoder tokenizer.Codec, text string) int { return tkm } -func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { - if imageUrl == nil { +func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) { + if fileMeta == nil { return 0, fmt.Errorf("image_url_is_nil") } + + // Defaults for 4o/4.1/4.5 family unless overridden below baseTokens := 85 - if model == "glm-4v" { + tileTokens := 170 + + // Model classification + lowerModel := strings.ToLower(model) + + // Special cases from existing behavior + if strings.HasPrefix(lowerModel, "glm-4") { return 1047, nil } - if imageUrl.Detail == "low" { + + // Patch-based models (32x32 patches, capped at 1536, with multiplier) + isPatchBased := false + multiplier := 1.0 + switch { + case strings.Contains(lowerModel, "gpt-4.1-mini"): + isPatchBased = true + multiplier = 1.62 + case strings.Contains(lowerModel, "gpt-4.1-nano"): + isPatchBased = true + multiplier = 2.46 + case strings.HasPrefix(lowerModel, "o4-mini"): + isPatchBased = true + multiplier = 1.72 + case strings.HasPrefix(lowerModel, "gpt-5-mini"): + isPatchBased = true + multiplier = 1.62 + case strings.HasPrefix(lowerModel, "gpt-5-nano"): + isPatchBased = true + multiplier = 2.46 + } + + // Tile-based model tokens and bases per doc + if !isPatchBased { + if strings.HasPrefix(lowerModel, "gpt-4o-mini") { + baseTokens = 2833 + tileTokens = 5667 + } else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) { + baseTokens = 70 + tileTokens = 140 + } else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") { + baseTokens = 75 + tileTokens = 150 + } else if strings.Contains(lowerModel, "computer-use-preview") { + baseTokens = 65 + tileTokens = 129 + } else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") { + baseTokens = 85 + tileTokens = 170 + } + } + + // Respect existing feature flags/short-circuits + if fileMeta.Detail == "low" && !isPatchBased { return baseTokens, nil } if !constant.GetMediaTokenNotStream && !stream { return 3 * baseTokens, nil } - - // 同步One API的图片计费逻辑 - if imageUrl.Detail == "auto" || imageUrl.Detail == "" { - imageUrl.Detail = "high" + // Normalize detail + if fileMeta.Detail == "auto" || fileMeta.Detail == "" { + fileMeta.Detail = "high" } - - tileTokens := 170 - if strings.HasPrefix(model, "gpt-4o-mini") { - tileTokens = 5667 - baseTokens = 2833 - } - // 是否统计图片token + // Whether to count image tokens at all if !constant.GetMediaToken { return 3 * baseTokens, nil } - if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic { - return 3 * baseTokens, nil - } + + // Decode image to get dimensions var config image.Config var err error var format string var b64str string - if strings.HasPrefix(imageUrl.Url, "http") { - config, format, err = DecodeUrlImageData(imageUrl.Url) + if strings.HasPrefix(fileMeta.Data, "http") { + config, format, err = DecodeUrlImageData(fileMeta.Data) } else { common.SysLog(fmt.Sprintf("decoding image")) - config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url) + config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data) } if err != nil { return 0, err } - imageUrl.MimeType = format + fileMeta.MimeType = format if config.Width == 0 || config.Height == 0 { // not an image @@ -125,60 +171,155 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m // file type return 3 * baseTokens, nil } - return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url)) + return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data)) } - shortSide := config.Width - otherSide := config.Height - log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height) - // 缩放倍数 - scale := 1.0 - if config.Height < shortSide { - shortSide = config.Height - otherSide = config.Width + width := config.Width + height := config.Height + log.Printf("format: %s, width: %d, height: %d", format, width, height) + + if isPatchBased { + // 32x32 patch-based calculation with 1536 cap and model multiplier + ceilDiv := func(a, b int) int { return (a + b - 1) / b } + rawPatchesW := ceilDiv(width, 32) + rawPatchesH := ceilDiv(height, 32) + rawPatches := rawPatchesW * rawPatchesH + if rawPatches > 1536 { + // scale down + area := float64(width * height) + r := math.Sqrt(float64(32*32*1536) / area) + wScaled := float64(width) * r + hScaled := float64(height) * r + // adjust to fit whole number of patches after scaling + adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0) + adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0) + adj := math.Min(adjW, adjH) + if !math.IsNaN(adj) && adj > 0 { + r = r * adj + } + wScaled = float64(width) * r + hScaled = float64(height) * r + patchesW := math.Ceil(wScaled / 32.0) + patchesH := math.Ceil(hScaled / 32.0) + imageTokens := int(patchesW * patchesH) + if imageTokens > 1536 { + imageTokens = 1536 + } + return int(math.Round(float64(imageTokens) * multiplier)), nil + } + // below cap + imageTokens := rawPatches + return int(math.Round(float64(imageTokens) * multiplier)), nil } - // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768 - if shortSide > 768 { - scale = float64(shortSide) / 768 - shortSide = 768 + // Tile-based calculation for 4o/4.1/4.5/o1/o3/etc. + // Step 1: fit within 2048x2048 square + maxSide := math.Max(float64(width), float64(height)) + fitScale := 1.0 + if maxSide > 2048 { + fitScale = maxSide / 2048.0 } - // 将另一边按照相同的比例缩小,向上取整 - otherSide = int(math.Ceil(float64(otherSide) / scale)) - log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale) - // 计算图片的token数量(边的长度除以512,向上取整) - tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512) - log.Printf("tiles: %d", tiles) + fitW := int(math.Round(float64(width) / fitScale)) + fitH := int(math.Round(float64(height) / fitScale)) + + // Step 2: scale so that shortest side is exactly 768 + minSide := math.Min(float64(fitW), float64(fitH)) + if minSide == 0 { + return baseTokens, nil + } + shortScale := 768.0 / minSide + finalW := int(math.Round(float64(fitW) * shortScale)) + finalH := int(math.Round(float64(fitH) * shortScale)) + + // Count 512px tiles + tilesW := (finalW + 512 - 1) / 512 + tilesH := (finalH + 512 - 1) / 512 + tiles := tilesW * tilesH + + if common.DebugEnabled { + log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles) + } + return tiles*tileTokens + baseTokens, nil } -func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) { - tkm := 0 - msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream) - if err != nil { - return 0, err - } - tkm += msgTokens - if request.Tools != nil { - openaiTools := request.Tools - countStr := "" - for _, tool := range openaiTools { - countStr = tool.Function.Name - if tool.Function.Description != "" { - countStr += tool.Function.Description - } - if tool.Function.Parameters != nil { - countStr += fmt.Sprintf("%v", tool.Function.Parameters) - } - } - toolTokens := CountTokenInput(countStr, request.Model) - tkm += 8 - tkm += toolTokens +func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { + if meta == nil { + return 0, errors.New("token count meta is nil") } + if info.RelayFormat == types.RelayFormatOpenAIRealtime { + return 0, nil + } + + model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) + tkm := 0 + + if meta.TokenType == types.TokenTypeTextNumber { + tkm += utf8.RuneCountInString(meta.CombineText) + } else { + tkm += CountTextToken(meta.CombineText, model) + } + + if info.RelayFormat == types.RelayFormatOpenAI { + tkm += meta.ToolsCount * 8 + tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量 + tkm += meta.NameCount * 3 + tkm += 3 + } + + for _, file := range meta.Files { + switch file.FileType { + case types.FileTypeImage: + if info.RelayFormat == types.RelayFormatGemini { + tkm += 240 + } else { + token, err := getImageToken(file, model, info.IsStream) + if err != nil { + return 0, fmt.Errorf("error counting image token: %v", err) + } + tkm += token + } + case types.FileTypeAudio: + tkm += 100 + case types.FileTypeVideo: + tkm += 5000 + case types.FileTypeFile: + tkm += 5000 + } + } + + common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm) return tkm, nil } +//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) { +// tkm := 0 +// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream) +// if err != nil { +// return 0, err +// } +// tkm += msgTokens +// if request.Tools != nil { +// openaiTools := request.Tools +// countStr := "" +// for _, tool := range openaiTools { +// countStr = tool.Function.Name +// if tool.Function.Description != "" { +// countStr += tool.Function.Description +// } +// if tool.Function.Parameters != nil { +// countStr += fmt.Sprintf("%v", tool.Function.Parameters) +// } +// } +// toolTokens := CountTokenInput(countStr, request.Model) +// tkm += 8 +// tkm += toolTokens +// } +// +// return tkm, nil +//} + func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) { tkm := 0 @@ -338,58 +479,55 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, return textToken, audioToken, nil } -func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) { - //recover when panic - tokenEncoder := getTokenEncoder(model) - // Reference: - // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - // https://github.com/pkoukk/tiktoken-go/issues/6 - // - // Every message follows <|start|>{role/name}\n{content}<|end|>\n - var tokensPerMessage int - var tokensPerName int - if model == "gpt-3.5-turbo-0301" { - tokensPerMessage = 4 - tokensPerName = -1 // If there's a name, the role is omitted - } else { - tokensPerMessage = 3 - tokensPerName = 1 - } - tokenNum := 0 - for _, message := range messages { - tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.Role) - if message.Content != nil { - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) - } - arrayContent := message.ParseContent() - for _, m := range arrayContent { - if m.Type == dto.ContentTypeImageURL { - imageUrl := m.GetImageMedia() - imageTokenNum, err := getImageToken(info, imageUrl, model, stream) - if err != nil { - return 0, err - } - tokenNum += imageTokenNum - log.Printf("image token num: %d", imageTokenNum) - } else if m.Type == dto.ContentTypeInputAudio { - // TODO: 音频token数量计算 - tokenNum += 100 - } else if m.Type == dto.ContentTypeFile { - tokenNum += 5000 - } else if m.Type == dto.ContentTypeVideoUrl { - tokenNum += 5000 - } else { - tokenNum += getTokenNum(tokenEncoder, m.Text) - } - } - } - } - tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum, nil -} +//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) { +// //recover when panic +// tokenEncoder := getTokenEncoder(model) +// // Reference: +// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb +// // https://github.com/pkoukk/tiktoken-go/issues/6 +// // +// // Every message follows <|start|>{role/name}\n{content}<|end|>\n +// var tokensPerMessage int +// var tokensPerName int +// +// tokensPerMessage = 3 +// tokensPerName = 1 +// +// tokenNum := 0 +// for _, message := range messages { +// tokenNum += tokensPerMessage +// tokenNum += getTokenNum(tokenEncoder, message.Role) +// if message.Content != nil { +// if message.Name != nil { +// tokenNum += tokensPerName +// tokenNum += getTokenNum(tokenEncoder, *message.Name) +// } +// arrayContent := message.ParseContent() +// for _, m := range arrayContent { +// if m.Type == dto.ContentTypeImageURL { +// imageUrl := m.GetImageMedia() +// imageTokenNum, err := getImageToken(info, imageUrl, model, stream) +// if err != nil { +// return 0, err +// } +// tokenNum += imageTokenNum +// log.Printf("image token num: %d", imageTokenNum) +// } else if m.Type == dto.ContentTypeInputAudio { +// // TODO: 音频token数量计算 +// tokenNum += 100 +// } else if m.Type == dto.ContentTypeFile { +// tokenNum += 5000 +// } else if m.Type == dto.ContentTypeVideoUrl { +// tokenNum += 5000 +// } else { +// tokenNum += getTokenNum(tokenEncoder, m.Text) +// } +// } +// } +// } +// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> +// return tokenNum, nil +//} func CountTokenInput(input any, model string) int { switch v := input.(type) { diff --git a/service/user_notify.go b/service/user_notify.go index 96664007..7c864a1b 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -12,7 +12,7 @@ func NotifyRootUser(t string, subject string, content string) { user := model.GetRootUser().ToBaseUser() err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil)) if err != nil { - common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error())) + common.SysLog(fmt.Sprintf("failed to notify root user: %s", err.Error())) } } @@ -25,7 +25,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // Check notification limit canSend, err := CheckNotificationLimit(userId, data.Type) if err != nil { - common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) + common.SysLog(fmt.Sprintf("failed to check notification limit: %s", err.Error())) return err } if !canSend { @@ -44,7 +44,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data case dto.NotifyTypeWebhook: webhookURLStr := userSetting.WebhookUrl if webhookURLStr == "" { - common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) + common.SysLog(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) return nil } diff --git a/setting/chat.go b/setting/chat.go index b97d65ce..bd1e26e3 100644 --- a/setting/chat.go +++ b/setting/chat.go @@ -37,7 +37,7 @@ func UpdateChatsByJsonString(jsonString string) error { func Chats2JsonString() string { jsonBytes, err := json.Marshal(Chats) if err != nil { - common.SysError("error marshalling chats: " + err.Error()) + common.SysLog("error marshalling chats: " + err.Error()) return "[]" } return string(jsonBytes) diff --git a/setting/rate_limit.go b/setting/rate_limit.go index d550b2c3..141463e1 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -21,7 +21,7 @@ func ModelRequestRateLimitGroup2JSONString() string { jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - common.SysError("error marshalling model ratio: " + err.Error()) + common.SysLog("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index 3f223bc3..5993cdee 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -89,7 +89,7 @@ func CacheRatio2JSONString() string { defer cacheRatioMapMutex.RUnlock() jsonBytes, err := json.Marshal(cacheRatioMap) if err != nil { - common.SysError("error marshalling cache ratio: " + err.Error()) + common.SysLog("error marshalling cache ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/group_ratio.go b/setting/ratio_setting/group_ratio.go index 86f4a8d1..c42553da 100644 --- a/setting/ratio_setting/group_ratio.go +++ b/setting/ratio_setting/group_ratio.go @@ -48,7 +48,7 @@ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(groupRatio) if err != nil { - common.SysError("error marshalling model ratio: " + err.Error()) + common.SysLog("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -67,7 +67,7 @@ func GetGroupRatio(name string) float64 { ratio, ok := groupRatio[name] if !ok { - common.SysError("group ratio not found: " + name) + common.SysLog("group ratio not found: " + name) return 1 } return ratio @@ -94,7 +94,7 @@ func GroupGroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupGroupRatio) if err != nil { - common.SysError("error marshalling group-group ratio: " + err.Error()) + common.SysLog("error marshalling group-group ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 4a19895e..ce822800 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -320,7 +320,7 @@ func ModelPrice2JSONString() string { modelPriceMapMutex.RLock() defer modelPriceMapMutex.RUnlock() - jsonBytes, err := json.Marshal(modelPriceMap) + jsonBytes, err := common.Marshal(modelPriceMap) if err != nil { common.SysError("error marshalling model price: " + err.Error()) } @@ -359,7 +359,7 @@ func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() modelRatioMap = make(map[string]float64) - err := json.Unmarshal([]byte(jsonStr), &modelRatioMap) + err := common.Unmarshal([]byte(jsonStr), &modelRatioMap) if err == nil { InvalidateExposedDataCache() } @@ -388,7 +388,7 @@ func GetModelRatio(name string) (float64, bool, string) { } func DefaultModelRatio2JSONString() string { - jsonBytes, err := json.Marshal(defaultModelRatio) + jsonBytes, err := common.Marshal(defaultModelRatio) if err != nil { common.SysError("error marshalling model ratio: " + err.Error()) } @@ -420,7 +420,7 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { CompletionRatioMutex.Lock() defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) - err := json.Unmarshal([]byte(jsonStr), &CompletionRatio) + err := common.Unmarshal([]byte(jsonStr), &CompletionRatio) if err == nil { InvalidateExposedDataCache() } @@ -594,7 +594,7 @@ func ModelRatio2JSONString() string { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() - jsonBytes, err := json.Marshal(modelRatioMap) + jsonBytes, err := common.Marshal(modelRatioMap) if err != nil { common.SysError("error marshalling model ratio: " + err.Error()) } @@ -610,7 +610,7 @@ var imageRatioMapMutex sync.RWMutex func ImageRatio2JSONString() string { imageRatioMapMutex.RLock() defer imageRatioMapMutex.RUnlock() - jsonBytes, err := json.Marshal(imageRatioMap) + jsonBytes, err := common.Marshal(imageRatioMap) if err != nil { common.SysError("error marshalling cache ratio: " + err.Error()) } @@ -621,7 +621,7 @@ func UpdateImageRatioByJSONString(jsonStr string) error { imageRatioMapMutex.Lock() defer imageRatioMapMutex.Unlock() imageRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &imageRatioMap) + return common.Unmarshal([]byte(jsonStr), &imageRatioMap) } func GetImageRatio(name string) (float64, bool) { diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go index 0ae132d0..57e4beec 100644 --- a/setting/user_usable_group.go +++ b/setting/user_usable_group.go @@ -29,7 +29,7 @@ func UserUsableGroups2JSONString() string { jsonBytes, err := json.Marshal(userUsableGroups) if err != nil { - common.SysError("error marshalling user groups: " + err.Error()) + common.SysLog("error marshalling user groups: " + err.Error()) } return string(jsonBytes) } diff --git a/types/error.go b/types/error.go index 5a143612..07486c27 100644 --- a/types/error.go +++ b/types/error.go @@ -39,12 +39,13 @@ const ( ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected" // new api error - ErrorCodeCountTokenFailed ErrorCode = "count_token_failed" - ErrorCodeModelPriceError ErrorCode = "model_price_error" - ErrorCodeInvalidApiType ErrorCode = "invalid_api_type" - ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed" - ErrorCodeDoRequestFailed ErrorCode = "do_request_failed" - ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed" + ErrorCodeCountTokenFailed ErrorCode = "count_token_failed" + ErrorCodeModelPriceError ErrorCode = "model_price_error" + ErrorCodeInvalidApiType ErrorCode = "invalid_api_type" + ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed" + ErrorCodeDoRequestFailed ErrorCode = "do_request_failed" + ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed" + ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed" // channel error ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" @@ -66,6 +67,7 @@ const ( ErrorCodeBadResponseBody ErrorCode = "bad_response_body" ErrorCodeEmptyResponse ErrorCode = "empty_response" ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error" + ErrorCodeModelNotFound ErrorCode = "model_not_found" // sql error ErrorCodeQueryDataError ErrorCode = "query_data_error" @@ -118,7 +120,8 @@ func (e *NewAPIError) MaskSensitiveError() string { if e.Err == nil { return string(e.errorCode) } - return common.MaskSensitiveInfo(e.Err.Error()) + errStr := e.Err.Error() + return common.MaskSensitiveInfo(errStr) } func (e *NewAPIError) SetMessage(message string) { diff --git a/types/price_data.go b/types/price_data.go new file mode 100644 index 00000000..f6a92d7e --- /dev/null +++ b/types/price_data.go @@ -0,0 +1,31 @@ +package types + +import "fmt" + +type GroupRatioInfo struct { + GroupRatio float64 + GroupSpecialRatio float64 + HasSpecialRatio bool +} + +type PriceData struct { + ModelPrice float64 + ModelRatio float64 + CompletionRatio float64 + CacheRatio float64 + CacheCreationRatio float64 + ImageRatio float64 + UsePrice bool + ShouldPreConsumedQuota int + GroupRatioInfo GroupRatioInfo +} + +type PerCallPriceData struct { + ModelPrice float64 + Quota int + GroupRatioInfo GroupRatioInfo +} + +func (p PriceData) ToSetting() string { + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) +} diff --git a/types/relay_format.go b/types/relay_format.go new file mode 100644 index 00000000..6d94a70b --- /dev/null +++ b/types/relay_format.go @@ -0,0 +1,18 @@ +package types + +type RelayFormat string + +const ( + RelayFormatOpenAI RelayFormat = "openai" + RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" + RelayFormatOpenAIResponses = "openai_responses" + RelayFormatOpenAIAudio = "openai_audio" + RelayFormatOpenAIImage = "openai_image" + RelayFormatOpenAIRealtime = "openai_realtime" + RelayFormatRerank = "rerank" + RelayFormatEmbedding = "embedding" + + RelayFormatTask = "task" + RelayFormatMjProxy = "mj_proxy" +) diff --git a/types/request_meta.go b/types/request_meta.go new file mode 100644 index 00000000..427bacb9 --- /dev/null +++ b/types/request_meta.go @@ -0,0 +1,45 @@ +package types + +type FileType string + +const ( + FileTypeImage FileType = "image" // Image file type + FileTypeAudio FileType = "audio" // Audio file type + FileTypeVideo FileType = "video" // Video file type + FileTypeFile FileType = "file" // Generic file type +) + +type TokenType string + +const ( + TokenTypeTextNumber TokenType = "text_number" // Text or number tokens + TokenTypeTokenizer TokenType = "tokenizer" // Tokenizer tokens + TokenTypeImage TokenType = "image" // Image tokens +) + +type TokenCountMeta struct { + TokenType TokenType `json:"token_type,omitempty"` // Type of tokens used in the request + CombineText string `json:"combine_text,omitempty"` // Combined text from all messages + ToolsCount int `json:"tools_count,omitempty"` // Number of tools used + NameCount int `json:"name_count,omitempty"` // Number of names in the request + MessagesCount int `json:"messages_count,omitempty"` // Number of messages in the request + Files []*FileMeta `json:"files,omitempty"` // List of files, each with type and content + MaxTokens int `json:"max_tokens,omitempty"` // Maximum tokens allowed in the request + + ImagePriceRatio float64 `json:"image_ratio,omitempty"` // Ratio for image size, if applicable + //IsStreaming bool `json:"is_streaming,omitempty"` // Indicates if the request is streaming +} + +type FileMeta struct { + FileType + MimeType string + Data string + Detail string +} + +type RequestMeta struct { + OriginalModelName string `json:"original_model_name"` + UserUsingGroup string `json:"user_using_group"` + PromptTokens int `json:"prompt_tokens"` + PreConsumedQuota int `json:"pre_consumed_quota"` +} diff --git a/web/src/components/layout/HeaderBar.js b/web/src/components/layout/HeaderBar.js index 8f4d1990..61f82ba2 100644 --- a/web/src/components/layout/HeaderBar.js +++ b/web/src/components/layout/HeaderBar.js @@ -648,7 +648,8 @@ const HeaderBar = ({ onMobileMenuToggle, drawerOpen }) => {
{ ); }; -export default HeaderBar; +export default HeaderBar; \ No newline at end of file