diff --git a/common/api_type.go b/common/api_type.go new file mode 100644 index 00000000..d9071236 --- /dev/null +++ b/common/api_type.go @@ -0,0 +1,71 @@ +package common + +import "one-api/constant" + +func ChannelType2APIType(channelType int) (int, bool) { + apiType := -1 + switch channelType { + case constant.ChannelTypeOpenAI: + apiType = constant.APITypeOpenAI + case constant.ChannelTypeAnthropic: + apiType = constant.APITypeAnthropic + case constant.ChannelTypeBaidu: + apiType = constant.APITypeBaidu + case constant.ChannelTypePaLM: + apiType = constant.APITypePaLM + case constant.ChannelTypeZhipu: + apiType = constant.APITypeZhipu + case constant.ChannelTypeAli: + apiType = constant.APITypeAli + case constant.ChannelTypeXunfei: + apiType = constant.APITypeXunfei + case constant.ChannelTypeAIProxyLibrary: + apiType = constant.APITypeAIProxyLibrary + case constant.ChannelTypeTencent: + apiType = constant.APITypeTencent + case constant.ChannelTypeGemini: + apiType = constant.APITypeGemini + case constant.ChannelTypeZhipu_v4: + apiType = constant.APITypeZhipuV4 + case constant.ChannelTypeOllama: + apiType = constant.APITypeOllama + case constant.ChannelTypePerplexity: + apiType = constant.APITypePerplexity + case constant.ChannelTypeAws: + apiType = constant.APITypeAws + case constant.ChannelTypeCohere: + apiType = constant.APITypeCohere + case constant.ChannelTypeDify: + apiType = constant.APITypeDify + case constant.ChannelTypeJina: + apiType = constant.APITypeJina + case constant.ChannelCloudflare: + apiType = constant.APITypeCloudflare + case constant.ChannelTypeSiliconFlow: + apiType = constant.APITypeSiliconFlow + case constant.ChannelTypeVertexAi: + apiType = constant.APITypeVertexAi + case constant.ChannelTypeMistral: + apiType = constant.APITypeMistral + case constant.ChannelTypeDeepSeek: + apiType = constant.APITypeDeepSeek + case constant.ChannelTypeMokaAI: + apiType = constant.APITypeMokaAI + case constant.ChannelTypeVolcEngine: + apiType = constant.APITypeVolcEngine + case constant.ChannelTypeBaiduV2: + apiType = constant.APITypeBaiduV2 + case constant.ChannelTypeOpenRouter: + apiType = constant.APITypeOpenRouter + case constant.ChannelTypeXinference: + apiType = constant.APITypeXinference + case constant.ChannelTypeXai: + apiType = constant.APITypeXai + case constant.ChannelTypeCoze: + apiType = constant.APITypeCoze + } + if apiType == -1 { + return constant.APITypeOpenAI, false + } + return apiType, true +} diff --git a/common/constants.go b/common/constants.go index 67625439..e4f5f047 100644 --- a/common/constants.go +++ b/common/constants.go @@ -193,111 +193,3 @@ const ( ChannelStatusManuallyDisabled = 2 // also don't use 0 ChannelStatusAutoDisabled = 3 ) - -const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeMidjourney = 2 - ChannelTypeAzure = 3 - ChannelTypeOllama = 4 - ChannelTypeMidjourneyPlus = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 - ChannelTypeOpenRouter = 20 - ChannelTypeAIProxyLibrary = 21 - ChannelTypeFastGPT = 22 - ChannelTypeTencent = 23 - ChannelTypeGemini = 24 - ChannelTypeMoonshot = 25 - ChannelTypeZhipu_v4 = 26 - ChannelTypePerplexity = 27 - ChannelTypeLingYiWanWu = 31 - ChannelTypeAws = 33 - ChannelTypeCohere = 34 - ChannelTypeMiniMax 
= 35 - ChannelTypeSunoAPI = 36 - ChannelTypeDify = 37 - ChannelTypeJina = 38 - ChannelCloudflare = 39 - ChannelTypeSiliconFlow = 40 - ChannelTypeVertexAi = 41 - ChannelTypeMistral = 42 - ChannelTypeDeepSeek = 43 - ChannelTypeMokaAI = 44 - ChannelTypeVolcEngine = 45 - ChannelTypeBaiduV2 = 46 - ChannelTypeXinference = 47 - ChannelTypeXai = 48 - ChannelTypeCoze = 49 - ChannelTypeKling = 50 - ChannelTypeJimeng = 51 - ChannelTypeDummy // this one is only for count, do not add any channel after this - -) - -var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "http://localhost:11434", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://api.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.tencentcloudapi.com", //23 - "https://generativelanguage.googleapis.com", //24 - "https://api.moonshot.cn", //25 - "https://open.bigmodel.cn", //26 - "https://api.perplexity.ai", //27 - "", //28 - "", //29 - "", //30 - "https://api.lingyiwanwu.com", //31 - "", //32 - "", //33 - "https://api.cohere.ai", //34 - "https://api.minimax.chat", //35 - "", //36 - "https://api.dify.ai", //37 - "https://api.jina.ai", //38 - "https://api.cloudflare.com", //39 - "https://api.siliconflow.cn", //40 - "", //41 - "https://api.mistral.ai", //42 - "https://api.deepseek.com", //43 - "https://api.moka.ai", //44 - "https://ark.cn-beijing.volces.com", //45 - "https://qianfan.baidubce.com", //46 - "", //47 - "https://api.x.ai", //48 - "https://api.coze.cn", //49 - "https://api.klingai.com", //50 - "https://visual.volcengineapi.com", //51 -} diff --git a/common/endpoint_type.go b/common/endpoint_type.go new file mode 100644 index 00000000..09174f23 --- /dev/null +++ b/common/endpoint_type.go @@ -0,0 +1,29 @@ +package common + +import "one-api/constant" + +// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点) +func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType { + var endpointTypes []constant.EndpointType + switch channelType { + case constant.ChannelTypeJina: + endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank} + case constant.ChannelTypeAws: + fallthrough + case constant.ChannelTypeAnthropic: + endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI} + case constant.ChannelTypeVertexAi: + fallthrough + case constant.ChannelTypeGemini: + endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI} + case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点 + endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI} + default: + if IsOpenAIResponseOnlyModel(modelName) { + endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse} + } else { + endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI} + } + } + return endpointTypes +} diff --git a/common/gin.go b/common/gin.go index 0614f735..62c4c692 100644 --- a/common/gin.go +++ b/common/gin.go @@ -4,7 +4,9 @@ import ( "bytes" 
"github.com/gin-gonic/gin" "io" + "one-api/constant" "strings" + "time" ) const KeyRequestBody = "key_request_body" @@ -42,3 +44,35 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return nil } + +func SetContextKey(c *gin.Context, key constant.ContextKey, value any) { + c.Set(string(key), value) +} + +func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) { + return c.Get(string(key)) +} + +func GetContextKeyString(c *gin.Context, key constant.ContextKey) string { + return c.GetString(string(key)) +} + +func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int { + return c.GetInt(string(key)) +} + +func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool { + return c.GetBool(string(key)) +} + +func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string { + return c.GetStringSlice(string(key)) +} + +func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any { + return c.GetStringMap(string(key)) +} + +func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time { + return c.GetTime(string(key)) +} diff --git a/common/init.go b/common/init.go index dd680db2..d70a09dd 100644 --- a/common/init.go +++ b/common/init.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "log" + "one-api/constant" "os" "path/filepath" "strconv" @@ -24,7 +25,7 @@ func printHelp() { fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]") } -func InitCommonEnv() { +func InitEnv() { flag.Parse() if *PrintVersion { @@ -95,4 +96,25 @@ func InitCommonEnv() { GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true) GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180)) + + initConstantEnv() +} + +func initConstantEnv() { + constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120) + constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true) + constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) + // ForceStreamOption 覆盖请求参数,强制返回usage信息 + constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) + constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) + constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) + constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true) + constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview") + constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) + constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) + constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) + // GenerateDefaultToken 是否生成初始令牌,默认关闭。 + constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false) + // 是否启用错误日志 + constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false) } diff --git a/common/model.go b/common/model.go new file mode 100644 index 00000000..6afb1540 --- /dev/null +++ b/common/model.go @@ -0,0 +1,21 @@ +package common + +import "strings" + +var ( + // OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses. 
+ OpenAIResponseOnlyModels = []string{ + "o3-pro", + "o3-deep-research", + "o4-mini-deep-research", + } +) + +func IsOpenAIResponseOnlyModel(modelName string) bool { + for _, m := range OpenAIResponseOnlyModels { + if strings.Contains(modelName, m) { + return true + } + } + return false +} diff --git a/common/redis.go b/common/redis.go index 1efc217f..c7287837 100644 --- a/common/redis.go +++ b/common/redis.go @@ -16,6 +16,10 @@ import ( var RDB *redis.Client var RedisEnabled = true +func RedisKeyCacheSeconds() int { + return SyncFrequency +} + // InitRedisClient This function is called after init() func InitRedisClient() (err error) { if os.Getenv("REDIS_CONN_STRING") == "" { diff --git a/constant/README.md b/constant/README.md new file mode 100644 index 00000000..12a9ffad --- /dev/null +++ b/constant/README.md @@ -0,0 +1,26 @@ +# constant 包 (`/constant`) + +该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。 + +## 当前文件 + +| 文件 | 说明 | +|----------------------|---------------------------------------------------------------------| +| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 | +| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 | +| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 | +| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 | +| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 | +| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 | +| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 | +| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 | +| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 | +| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 | + +## 使用约定 + +1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。 +2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。 +3. 
新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。 + +> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。 \ No newline at end of file diff --git a/constant/api_type.go b/constant/api_type.go new file mode 100644 index 00000000..ae867870 --- /dev/null +++ b/constant/api_type.go @@ -0,0 +1,34 @@ +package constant + +const ( + APITypeOpenAI = iota + APITypeAnthropic + APITypePaLM + APITypeBaidu + APITypeZhipu + APITypeAli + APITypeXunfei + APITypeAIProxyLibrary + APITypeTencent + APITypeGemini + APITypeZhipuV4 + APITypeOllama + APITypePerplexity + APITypeAws + APITypeCohere + APITypeDify + APITypeJina + APITypeCloudflare + APITypeSiliconFlow + APITypeVertexAi + APITypeMistral + APITypeDeepSeek + APITypeMokaAI + APITypeVolcEngine + APITypeBaiduV2 + APITypeOpenRouter + APITypeXinference + APITypeXai + APITypeCoze + APITypeDummy // this one is only for count, do not add any channel after this +) diff --git a/constant/cache_key.go b/constant/cache_key.go index daedfd40..0601396a 100644 --- a/constant/cache_key.go +++ b/constant/cache_key.go @@ -1,12 +1,5 @@ package constant -import "one-api/common" - -// 使用函数来避免初始化顺序带来的赋值问题 -func RedisKeyCacheSeconds() int { - return common.SyncFrequency -} - // Cache keys const ( UserGroupKeyFmt = "user_group:%d" diff --git a/constant/channel.go b/constant/channel.go new file mode 100644 index 00000000..224121e7 --- /dev/null +++ b/constant/channel.go @@ -0,0 +1,109 @@ +package constant + +const ( + ChannelTypeUnknown = 0 + ChannelTypeOpenAI = 1 + ChannelTypeMidjourney = 2 + ChannelTypeAzure = 3 + ChannelTypeOllama = 4 + ChannelTypeMidjourneyPlus = 5 + ChannelTypeOpenAIMax = 6 + ChannelTypeOhMyGPT = 7 + ChannelTypeCustom = 8 + ChannelTypeAILS = 9 + ChannelTypeAIProxy = 10 + ChannelTypePaLM = 11 + ChannelTypeAPI2GPT = 12 + ChannelTypeAIGC2D = 13 + ChannelTypeAnthropic = 14 + ChannelTypeBaidu = 15 + ChannelTypeZhipu = 16 + ChannelTypeAli = 17 + ChannelTypeXunfei = 18 + ChannelType360 = 19 + ChannelTypeOpenRouter = 20 + ChannelTypeAIProxyLibrary = 21 + ChannelTypeFastGPT = 22 + ChannelTypeTencent = 23 + ChannelTypeGemini = 24 + ChannelTypeMoonshot = 25 + ChannelTypeZhipu_v4 = 26 + ChannelTypePerplexity = 27 + ChannelTypeLingYiWanWu = 31 + ChannelTypeAws = 33 + ChannelTypeCohere = 34 + ChannelTypeMiniMax = 35 + ChannelTypeSunoAPI = 36 + ChannelTypeDify = 37 + ChannelTypeJina = 38 + ChannelCloudflare = 39 + ChannelTypeSiliconFlow = 40 + ChannelTypeVertexAi = 41 + ChannelTypeMistral = 42 + ChannelTypeDeepSeek = 43 + ChannelTypeMokaAI = 44 + ChannelTypeVolcEngine = 45 + ChannelTypeBaiduV2 = 46 + ChannelTypeXinference = 47 + ChannelTypeXai = 48 + ChannelTypeCoze = 49 + ChannelTypeKling = 50 + ChannelTypeJimeng = 51 + ChannelTypeDummy // this one is only for count, do not add any channel after this + +) + +var ChannelBaseURLs = []string{ + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "http://localhost:11434", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://api.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.tencentcloudapi.com", 
//23 + "https://generativelanguage.googleapis.com", //24 + "https://api.moonshot.cn", //25 + "https://open.bigmodel.cn", //26 + "https://api.perplexity.ai", //27 + "", //28 + "", //29 + "", //30 + "https://api.lingyiwanwu.com", //31 + "", //32 + "", //33 + "https://api.cohere.ai", //34 + "https://api.minimax.chat", //35 + "", //36 + "https://api.dify.ai", //37 + "https://api.jina.ai", //38 + "https://api.cloudflare.com", //39 + "https://api.siliconflow.cn", //40 + "", //41 + "https://api.mistral.ai", //42 + "https://api.deepseek.com", //43 + "https://api.moka.ai", //44 + "https://ark.cn-beijing.volces.com", //45 + "https://qianfan.baidubce.com", //46 + "", //47 + "https://api.x.ai", //48 + "https://api.coze.cn", //49 + "https://api.klingai.com", //50 + "https://visual.volcengineapi.com", //51 +} diff --git a/constant/context_key.go b/constant/context_key.go index 895b0fcb..71e02f01 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -1,11 +1,35 @@ package constant +type ContextKey string + const ( - ContextKeyRequestStartTime = "request_start_time" - ContextKeyUserSetting = "user_setting" - ContextKeyUserQuota = "user_quota" - ContextKeyUserStatus = "user_status" - ContextKeyUserEmail = "user_email" - ContextKeyUserGroup = "user_group" - ContextKeyUsingGroup = "group" + ContextKeyOriginalModel ContextKey = "original_model" + ContextKeyRequestStartTime ContextKey = "request_start_time" + + /* token related keys */ + ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota" + ContextKeyTokenKey ContextKey = "token_key" + ContextKeyTokenId ContextKey = "token_id" + ContextKeyTokenGroup ContextKey = "token_group" + ContextKeyTokenAllowIps ContextKey = "allow_ips" + ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id" + ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled" + ContextKeyTokenModelLimit ContextKey = "token_model_limit" + + /* channel related keys */ + ContextKeyBaseUrl ContextKey = "base_url" + ContextKeyChannelType ContextKey = "channel_type" + ContextKeyChannelId ContextKey = "channel_id" + ContextKeyChannelSetting ContextKey = "channel_setting" + ContextKeyParamOverride ContextKey = "param_override" + + /* user related keys */ + ContextKeyUserId ContextKey = "id" + ContextKeyUserSetting ContextKey = "user_setting" + ContextKeyUserQuota ContextKey = "user_quota" + ContextKeyUserStatus ContextKey = "user_status" + ContextKeyUserEmail ContextKey = "user_email" + ContextKeyUserGroup ContextKey = "user_group" + ContextKeyUsingGroup ContextKey = "group" + ContextKeyUserName ContextKey = "username" ) diff --git a/constant/endpoint_type.go b/constant/endpoint_type.go new file mode 100644 index 00000000..a1b840db --- /dev/null +++ b/constant/endpoint_type.go @@ -0,0 +1,11 @@ +package constant + +type EndpointType string + +const ( + EndpointTypeOpenAI EndpointType = "openai" + EndpointTypeOpenAIResponse EndpointType = "openai-response" + EndpointTypeAnthropic EndpointType = "anthropic" + EndpointTypeGemini EndpointType = "gemini" + EndpointTypeJinaRerank EndpointType = "jina-rerank" +) diff --git a/constant/env.go b/constant/env.go index f33c67ff..8bc2f131 100644 --- a/constant/env.go +++ b/constant/env.go @@ -1,9 +1,5 @@ package constant -import ( - "one-api/common" -) - var StreamingTimeout int var DifyDebug bool var MaxFileDownloadMB int @@ -17,39 +13,3 @@ var NotifyLimitCount int var NotificationLimitDurationMinute int var GenerateDefaultToken bool var ErrorLogEnabled bool - -//var GeminiModelMap = map[string]string{ -// 
"gemini-1.0-pro": "v1", -//} - -func InitEnv() { - StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 120) - DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true) - MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) - // ForceStreamOption 覆盖请求参数,强制返回usage信息 - ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) - GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) - GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) - UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) - AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview") - GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) - NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) - NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) - // GenerateDefaultToken 是否生成初始令牌,默认关闭。 - GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false) - // 是否启用错误日志 - ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false) - - //modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) - //if modelVersionMapStr == "" { - // return - //} - //for _, pair := range strings.Split(modelVersionMapStr, ",") { - // parts := strings.Split(pair, ":") - // if len(parts) == 2 { - // GeminiModelMap[parts[0]] = parts[1] - // } else { - // common.SysError(fmt.Sprintf("invalid model version map: %s", pair)) - // } - //} -} diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 9bf5d1fe..3c92c78b 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/model" "one-api/service" "one-api/setting" @@ -341,34 +342,34 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) { } func updateChannelBalance(channel *model.Channel) (float64, error) { - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := constant.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { channel.BaseURL = &baseURL } switch channel.Type { - case common.ChannelTypeOpenAI: + case constant.ChannelTypeOpenAI: if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } - case common.ChannelTypeAzure: + case constant.ChannelTypeAzure: return 0, errors.New("尚未实现") - case common.ChannelTypeCustom: + case constant.ChannelTypeCustom: baseURL = channel.GetBaseURL() //case common.ChannelTypeOpenAISB: // return updateChannelOpenAISBBalance(channel) - case common.ChannelTypeAIProxy: + case constant.ChannelTypeAIProxy: return updateChannelAIProxyBalance(channel) - case common.ChannelTypeAPI2GPT: + case constant.ChannelTypeAPI2GPT: return updateChannelAPI2GPTBalance(channel) - case common.ChannelTypeAIGC2D: + case constant.ChannelTypeAIGC2D: return updateChannelAIGC2DBalance(channel) - case common.ChannelTypeSiliconFlow: + case constant.ChannelTypeSiliconFlow: return updateChannelSiliconFlowBalance(channel) - case common.ChannelTypeDeepSeek: + case constant.ChannelTypeDeepSeek: return updateChannelDeepSeekBalance(channel) - case common.ChannelTypeOpenRouter: + case constant.ChannelTypeOpenRouter: return updateChannelOpenRouterBalance(channel) - case common.ChannelTypeMoonshot: + case constant.ChannelTypeMoonshot: return updateChannelMoonshotBalance(channel) default: return 0, errors.New("尚未实现") diff --git a/controller/channel-test.go 
b/controller/channel-test.go index b3badf35..92118540 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -11,12 +11,12 @@ import ( "net/http/httptest" "net/url" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/middleware" "one-api/model" "one-api/relay" relaycommon "one-api/relay/common" - "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "strconv" @@ -31,19 +31,19 @@ import ( func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { tik := time.Now() - if channel.Type == common.ChannelTypeMidjourney { + if channel.Type == constant.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil } - if channel.Type == common.ChannelTypeMidjourneyPlus { - return errors.New("midjourney plus channel test is not supported!!!"), nil + if channel.Type == constant.ChannelTypeMidjourneyPlus { + return errors.New("midjourney plus channel test is not supported"), nil } - if channel.Type == common.ChannelTypeSunoAPI { + if channel.Type == constant.ChannelTypeSunoAPI { return errors.New("suno channel test is not supported"), nil } - if channel.Type == common.ChannelTypeKling { + if channel.Type == constant.ChannelTypeKling { return errors.New("kling channel test is not supported"), nil } - if channel.Type == common.ChannelTypeJimeng { + if channel.Type == constant.ChannelTypeJimeng { return errors.New("jimeng channel test is not supported"), nil } w := httptest.NewRecorder() @@ -56,7 +56,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 strings.Contains(testModel, "bge-") || // bge 系列模型 strings.Contains(testModel, "embed") || - channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型 + channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型 requestPath = "/v1/embeddings" // 修改请求路径 } @@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } testModel = info.UpstreamModelName - apiType, _ := constant.ChannelType2APIType(channel.Type) + apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) if adaptor == nil { return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil diff --git a/controller/channel.go b/controller/channel.go index e46b38f5..98ef3c08 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/constant" "one-api/model" "strconv" "strings" @@ -125,7 +126,7 @@ func GetAllChannels(c *gin.Context) { order = "id desc" } - err := baseQuery.Order(order).Limit(pageSize).Offset((p-1)*pageSize).Omit("key").Find(&channelData).Error + err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error if err != nil { c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) return @@ -181,15 +182,15 @@ func FetchUpstreamModels(c *gin.Context) { return } - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := constant.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } url := fmt.Sprintf("%s/v1/models", baseURL) switch channel.Type { - case common.ChannelTypeGemini: + case constant.ChannelTypeGemini: url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) - case common.ChannelTypeAli: + case constant.ChannelTypeAli: url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) } body, err := 
GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) @@ -213,7 +214,7 @@ func FetchUpstreamModels(c *gin.Context) { var ids []string for _, model := range result.Data { id := model.ID - if channel.Type == common.ChannelTypeGemini { + if channel.Type == constant.ChannelTypeGemini { id = strings.TrimPrefix(id, "models/") } ids = append(ids, id) @@ -388,7 +389,7 @@ func AddChannel(c *gin.Context) { } channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") - if channel.Type == common.ChannelTypeVertexAi { + if channel.Type == constant.ChannelTypeVertexAi { if channel.Other == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -613,7 +614,7 @@ func UpdateChannel(c *gin.Context) { }) return } - if channel.Type == common.ChannelTypeVertexAi { + if channel.Type == constant.ChannelTypeVertexAi { if channel.Other == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -668,7 +669,7 @@ func FetchModels(c *gin.Context) { baseURL := req.BaseURL if baseURL == "" { - baseURL = common.ChannelBaseURLs[req.Type] + baseURL = constant.ChannelBaseURLs[req.Type] } client := &http.Client{} diff --git a/controller/model.go b/controller/model.go index 78bd32d6..360fffa6 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,6 +2,7 @@ package controller import ( "fmt" + "github.com/gin-gonic/gin" "github.com/samber/lo" "net/http" "one-api/common" @@ -14,10 +15,7 @@ import ( "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/setting" - - "github.com/gin-gonic/gin" ) // https://platform.openai.com/docs/api-reference/models/list @@ -26,30 +24,10 @@ var openAIModels []dto.OpenAIModels var openAIModelsMap map[string]dto.OpenAIModels var channelId2Models map[int][]string -func getPermission() []dto.OpenAIModelPermission { - var permission []dto.OpenAIModelPermission - permission = append(permission, dto.OpenAIModelPermission{ - Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ", - Object: "model_permission", - Created: 1626777600, - AllowCreateEngine: true, - AllowSampling: true, - AllowLogprobs: true, - AllowSearchIndices: false, - AllowView: true, - AllowFineTuning: false, - Organization: "*", - Group: nil, - IsBlocking: false, - }) - return permission -} - func init() { // https://platform.openai.com/docs/models/model-endpoint-compatibility - permission := getPermission() - for i := 0; i < relayconstant.APITypeDummy; i++ { - if i == relayconstant.APITypeAIProxyLibrary { + for i := 0; i < constant.APITypeDummy; i++ { + if i == constant.APITypeAIProxyLibrary { continue } adaptor := relay.GetAdaptor(i) @@ -57,69 +35,51 @@ func init() { modelNames := adaptor.GetModelList() for _, modelName := range modelNames { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: channelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: channelName, }) } } for _, modelName := range ai360.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: ai360.ChannelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: ai360.ChannelName, }) } for _, modelName := range moonshot.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 
1626777600, - OwnedBy: moonshot.ChannelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: moonshot.ChannelName, }) } for _, modelName := range lingyiwanwu.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: lingyiwanwu.ChannelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: lingyiwanwu.ChannelName, }) } for _, modelName := range minimax.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: minimax.ChannelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: minimax.ChannelName, }) } for modelName, _ := range constant.MidjourneyModel2Action { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "midjourney", - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "midjourney", }) } openAIModelsMap = make(map[string]dto.OpenAIModels) @@ -127,9 +87,9 @@ func init() { openAIModelsMap[aiModel.Id] = aiModel } channelId2Models = make(map[int][]string) - for i := 1; i <= common.ChannelTypeDummy; i++ { - apiType, success := relayconstant.ChannelType2APIType(i) - if !success || apiType == relayconstant.APITypeAIProxyLibrary { + for i := 1; i <= constant.ChannelTypeDummy; i++ { + apiType, success := common.ChannelType2APIType(i) + if !success || apiType == constant.APITypeAIProxyLibrary { continue } meta := &relaycommon.RelayInfo{ChannelType: i} @@ -144,11 +104,10 @@ func init() { func ListModels(c *gin.Context) { userOpenAiModels := make([]dto.OpenAIModels, 0) - permission := getPermission() - modelLimitEnable := c.GetBool("token_model_limit_enabled") + modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) if modelLimitEnable { - s, ok := c.Get("token_model_limit") + s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit) var tokenModelLimit map[string]bool if ok { tokenModelLimit = s.(map[string]bool) @@ -156,17 +115,16 @@ func ListModels(c *gin.Context) { tokenModelLimit = map[string]bool{} } for allowModel, _ := range tokenModelLimit { - if _, ok := openAIModelsMap[allowModel]; ok { - userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel]) + if oaiModel, ok := openAIModelsMap[allowModel]; ok { + oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel) + userOpenAiModels = append(userOpenAiModels, oaiModel) } else { userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ - Id: allowModel, - Object: "model", - Created: 1626777600, - OwnedBy: "custom", - Permission: permission, - Root: allowModel, - Parent: nil, + Id: allowModel, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel), }) } } @@ -181,14 +139,14 @@ func ListModels(c *gin.Context) { return } group := userGroup - tokenGroup := c.GetString("token_group") + tokenGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) if tokenGroup != "" { group = tokenGroup } var models []string if tokenGroup == "auto" { for _, autoGroup := range setting.AutoGroups { - groupModels := 
model.GetGroupModels(autoGroup) + groupModels := model.GetGroupEnabledModels(autoGroup) for _, g := range groupModels { if !common.StringsContains(models, g) { models = append(models, g) @@ -196,20 +154,19 @@ func ListModels(c *gin.Context) { } } } else { - models = model.GetGroupModels(group) + models = model.GetGroupEnabledModels(group) } - for _, s := range models { - if _, ok := openAIModelsMap[s]; ok { - userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s]) + for _, modelName := range models { + if oaiModel, ok := openAIModelsMap[modelName]; ok { + oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName) + userOpenAiModels = append(userOpenAiModels, oaiModel) } else { userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ - Id: s, - Object: "model", - Created: 1626777600, - OwnedBy: "custom", - Permission: permission, - Root: s, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName), }) } } diff --git a/controller/playground.go b/controller/playground.go index 10393250..33471455 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -65,7 +65,7 @@ func Playground(c *gin.Context) { return } middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) - c.Set(constant.ContextKeyRequestStartTime, time.Now()) + common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) // Write user context to ensure acceptUnsetRatio is available userId := c.GetInt("id") diff --git a/controller/relay.go b/controller/relay.go index 4da4262b..e375120b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -8,12 +8,12 @@ import ( "log" "net/http" "one-api/common" + "one-api/constant" constant2 "one-api/constant" "one-api/dto" "one-api/middleware" "one-api/model" "one-api/relay" - "one-api/relay/constant" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" @@ -69,7 +69,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode } func Relay(c *gin.Context) { - relayMode := constant.Path2RelayMode(c.Request.URL.Path) + relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") originalModel := c.GetString("original_model") @@ -132,7 +132,7 @@ func WssRelay(c *gin.Context) { return } - relayMode := constant.Path2RelayMode(c.Request.URL.Path) + relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 @@ -295,7 +295,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry } if openaiErr.StatusCode == http.StatusBadRequest { channelType := c.GetInt("channel_type") - if channelType == common.ChannelTypeAnthropic { + if channelType == constant.ChannelTypeAnthropic { return true } return false diff --git a/controller/task_video.go b/controller/task_video.go index a17351b5..b62978a7 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -51,7 +51,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha } func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := constant.ChannelBaseURLs[channel.Type] 
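// The hunks in channel-billing.go, channel.go and this file now share the same
// two-step lookup after the move to the constant package: take the per-type default
// from constant.ChannelBaseURLs, then let the channel's own BaseURL override it.
// A minimal sketch of that pattern as a helper (hypothetical resolveBaseURL, not part
// of this diff; it assumes the model.Channel accessors used above):
//
//	func resolveBaseURL(ch *model.Channel) string {
//		baseURL := constant.ChannelBaseURLs[ch.Type]
//		if ch.GetBaseURL() != "" {
//			baseURL = ch.GetBaseURL()
//		}
//		return baseURL
//	}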
if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } diff --git a/controller/user.go b/controller/user.go index 340a1e93..ca161f42 100644 --- a/controller/user.go +++ b/controller/user.go @@ -487,7 +487,7 @@ func GetUserModels(c *gin.Context) { groups := setting.GetUserUsableGroups(user.Group) var models []string for group := range groups { - for _, g := range model.GetGroupModels(group) { + for _, g := range model.GetGroupEnabledModels(group) { if !common.StringsContains(models, g) { models = append(models, g) } diff --git a/dto/pricing.go b/dto/pricing.go index ee77c098..0f317d9d 100644 --- a/dto/pricing.go +++ b/dto/pricing.go @@ -1,26 +1,11 @@ package dto -type OpenAIModelPermission struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - AllowCreateEngine bool `json:"allow_create_engine"` - AllowSampling bool `json:"allow_sampling"` - AllowLogprobs bool `json:"allow_logprobs"` - AllowSearchIndices bool `json:"allow_search_indices"` - AllowView bool `json:"allow_view"` - AllowFineTuning bool `json:"allow_fine_tuning"` - Organization string `json:"organization"` - Group *string `json:"group"` - IsBlocking bool `json:"is_blocking"` -} +import "one-api/constant" type OpenAIModels struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - OwnedBy string `json:"owned_by"` - Permission []OpenAIModelPermission `json:"permission"` - Root string `json:"root"` - Parent *string `json:"parent"` + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` } diff --git a/main.go b/main.go index 5e7656e9..727d5db6 100644 --- a/main.go +++ b/main.go @@ -169,10 +169,8 @@ func InitResources() error { common.SysLog("No .env file found, using default environment variables. 
If needed, please create a .env file and set the relevant variables.") } - // 加载旧的(common)环境变量 - common.InitCommonEnv() - // 加载constants的环境变量 - constant.InitEnv() + // 加载环境变量 + common.InitEnv() // Initialize model settings ratio_setting.InitRatioSettings() @@ -193,6 +191,9 @@ func InitResources() error { // Initialize options, should after model.InitDB() model.InitOptionMap() + // 初始化模型 + model.GetPricing() + // Initialize SQL Database err = model.InitLogDB() if err != nil { diff --git a/middleware/distributor.go b/middleware/distributor.go index 0a6a9af4..17916e7a 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -25,7 +25,7 @@ type ModelRequest struct { func Distribute() func(c *gin.Context) { return func(c *gin.Context) { - allowIpsMap := c.GetStringMap("allow_ips") + allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps) if len(allowIpsMap) != 0 { clientIp := c.ClientIP() if _, ok := allowIpsMap[clientIp]; !ok { @@ -34,14 +34,14 @@ func Distribute() func(c *gin.Context) { } } var channel *model.Channel - channelId, ok := c.Get("specific_channel_id") + channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId) modelRequest, shouldSelectChannel, err := getModelRequest(c) if err != nil { abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) return } - userGroup := c.GetString(constant.ContextKeyUserGroup) - tokenGroup := c.GetString("token_group") + userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) + tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) if tokenGroup != "" { // check common.UserUsableGroups[userGroup] if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok { @@ -57,7 +57,7 @@ func Distribute() func(c *gin.Context) { } userGroup = tokenGroup } - c.Set(constant.ContextKeyUsingGroup, userGroup) + common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup) if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -76,9 +76,9 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user // check token model mapping - modelLimitEnable := c.GetBool("token_model_limit_enabled") + modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) if modelLimitEnable { - s, ok := c.Get("token_model_limit") + s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit) var tokenModelLimit map[string]bool if ok { tokenModelLimit = s.(map[string]bool) @@ -121,7 +121,7 @@ func Distribute() func(c *gin.Context) { } } } - c.Set(constant.ContextKeyRequestStartTime, time.Now()) + common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) SetupContextForSelectedChannel(c, channel, modelRequest.Model) c.Next() } @@ -261,21 +261,21 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("base_url", channel.GetBaseURL()) // TODO: api_version统一 switch channel.Type { - case common.ChannelTypeAzure: + case constant.ChannelTypeAzure: c.Set("api_version", channel.Other) - case common.ChannelTypeVertexAi: + case constant.ChannelTypeVertexAi: c.Set("region", channel.Other) - case common.ChannelTypeXunfei: + case constant.ChannelTypeXunfei: c.Set("api_version", channel.Other) - case common.ChannelTypeGemini: + case constant.ChannelTypeGemini: c.Set("api_version", channel.Other) - case common.ChannelTypeAli: + case constant.ChannelTypeAli: c.Set("plugin", channel.Other) - case common.ChannelCloudflare: + case 
constant.ChannelCloudflare: c.Set("api_version", channel.Other) - case common.ChannelTypeMokaAI: + case constant.ChannelTypeMokaAI: c.Set("api_version", channel.Other) - case common.ChannelTypeCoze: + case constant.ChannelTypeCoze: c.Set("bot_id", channel.Other) } } diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 34caa59b..14d9a737 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) { successMaxCount := setting.ModelRequestRateLimitSuccessCount // 获取分组 - group := c.GetString("token_group") + group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) if group == "" { - group = c.GetString(constant.ContextKeyUserGroup) + group = common.GetContextKeyString(c, constant.ContextKeyUserGroup) } //获取分组的限流配置 diff --git a/model/ability.go b/model/ability.go index 96a9ef6a..fb5301fe 100644 --- a/model/ability.go +++ b/model/ability.go @@ -21,7 +21,22 @@ type Ability struct { Tag *string `json:"tag" gorm:"index"` } -func GetGroupModels(group string) []string { +type AbilityWithChannel struct { + Ability + ChannelType int `json:"channel_type"` +} + +func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) { + var abilities []AbilityWithChannel + err := DB.Table("abilities"). + Select("abilities.*, channels.type as channel_type"). + Joins("left join channels on abilities.channel_id = channels.id"). + Where("abilities.enabled = ?", true). + Scan(&abilities).Error + return abilities, err +} + +func GetGroupEnabledModels(group string) []string { var models []string // Find distinct models DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models) @@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) { var priorities []int err := DB.Model(&Ability{}). Select("DISTINCT(priority)"). - Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal). + Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true). Order("priority DESC"). // 按优先级降序排序 Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中 @@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) { } func getChannelQuery(group string, model string, retry int) *gorm.DB { - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal) - channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery) + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true) + channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery) if retry != 0 { priority, err := getPriority(group, model, retry) if err != nil { common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error())) } else { - channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority) + channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? 
and priority = ?", group, model, true, priority) } } diff --git a/model/pricing.go b/model/pricing.go index 74a25f2d..0c0216f1 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -1,20 +1,24 @@ package model import ( + "fmt" "one-api/common" + "one-api/constant" "one-api/setting/ratio_setting" + "one-api/types" "sync" "time" ) type Pricing struct { - ModelName string `json:"model_name"` - QuotaType int `json:"quota_type"` - ModelRatio float64 `json:"model_ratio"` - ModelPrice float64 `json:"model_price"` - OwnerBy string `json:"owner_by"` - CompletionRatio float64 `json:"completion_ratio"` - EnableGroup []string `json:"enable_groups,omitempty"` + ModelName string `json:"model_name"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + OwnerBy string `json:"owner_by"` + CompletionRatio float64 `json:"completion_ratio"` + EnableGroup []string `json:"enable_groups"` + SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` } var ( @@ -23,47 +27,89 @@ var ( updatePricingLock sync.Mutex ) -func GetPricing() []Pricing { - updatePricingLock.Lock() - defer updatePricingLock.Unlock() +var ( + modelSupportEndpointTypes = make(map[string][]constant.EndpointType) + modelSupportEndpointsLock = sync.RWMutex{} +) +func GetPricing() []Pricing { if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { - updatePricing() + updatePricingLock.Lock() + defer updatePricingLock.Unlock() + // Double check after acquiring the lock + if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { + modelSupportEndpointsLock.Lock() + defer modelSupportEndpointsLock.Unlock() + updatePricing() + } } - //if group != "" { - // userPricingMap := make([]Pricing, 0) - // models := GetGroupModels(group) - // for _, pricing := range pricingMap { - // if !common.StringsContains(models, pricing.ModelName) { - // pricing.Available = false - // } - // userPricingMap = append(userPricingMap, pricing) - // } - // return userPricingMap - //} return pricingMap } +func GetModelSupportEndpointTypes(model string) []constant.EndpointType { + if model == "" { + return make([]constant.EndpointType, 0) + } + modelSupportEndpointsLock.RLock() + defer modelSupportEndpointsLock.RUnlock() + if endpoints, ok := modelSupportEndpointTypes[model]; ok { + return endpoints + } + return make([]constant.EndpointType, 0) +} + func updatePricing() { //modelRatios := common.GetModelRatios() - enableAbilities := GetAllEnableAbilities() - modelGroupsMap := make(map[string][]string) + enableAbilities, err := GetAllEnableAbilityWithChannels() + if err != nil { + common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) + return + } + modelGroupsMap := make(map[string]*types.Set[string]) + for _, ability := range enableAbilities { - groups := modelGroupsMap[ability.Model] - if groups == nil { - groups = make([]string, 0) + groups, ok := modelGroupsMap[ability.Model] + if !ok { + groups = types.NewSet[string]() + modelGroupsMap[ability.Model] = groups } - if !common.StringsContains(groups, ability.Group) { - groups = append(groups, ability.Group) + groups.Add(ability.Group) + } + + //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点 + modelSupportEndpointsStr := make(map[string][]string) + + for _, ability := range enableAbilities { + endpoints, ok := modelSupportEndpointsStr[ability.Model] + if !ok { + endpoints = make([]string, 0) + modelSupportEndpointsStr[ability.Model] = endpoints } - modelGroupsMap[ability.Model] 
= groups + channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) + for _, channelType := range channelTypes { + if !common.StringsContains(endpoints, string(channelType)) { + endpoints = append(endpoints, string(channelType)) + } + } + modelSupportEndpointsStr[ability.Model] = endpoints + } + + modelSupportEndpointTypes = make(map[string][]constant.EndpointType) + for model, endpoints := range modelSupportEndpointsStr { + supportedEndpoints := make([]constant.EndpointType, 0) + for _, endpointStr := range endpoints { + endpointType := constant.EndpointType(endpointStr) + supportedEndpoints = append(supportedEndpoints, endpointType) + } + modelSupportEndpointTypes[model] = supportedEndpoints } pricingMap = make([]Pricing, 0) for model, groups := range modelGroupsMap { pricing := Pricing{ - ModelName: model, - EnableGroup: groups, + ModelName: model, + EnableGroup: groups.Items(), + SupportedEndpointTypes: modelSupportEndpointTypes[model], } modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) if findPrice { diff --git a/model/token_cache.go b/model/token_cache.go index a4b0beae..5399dbc8 100644 --- a/model/token_cache.go +++ b/model/token_cache.go @@ -10,7 +10,7 @@ import ( func cacheSetToken(token Token) error { key := common.GenerateHMAC(token.Key) token.Clean() - err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second) + err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second) if err != nil { return err } diff --git a/model/user_cache.go b/model/user_cache.go index e673defc..b4bc2f1e 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -24,12 +24,12 @@ type UserBase struct { } func (user *UserBase) WriteContext(c *gin.Context) { - c.Set(constant.ContextKeyUserGroup, user.Group) - c.Set(constant.ContextKeyUserQuota, user.Quota) - c.Set(constant.ContextKeyUserStatus, user.Status) - c.Set(constant.ContextKeyUserEmail, user.Email) - c.Set("username", user.Username) - c.Set(constant.ContextKeyUserSetting, user.GetSetting()) + common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group) + common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota) + common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status) + common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email) + common.SetContextKey(c, constant.ContextKeyUserName, user.Username) + common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting()) } func (user *UserBase) GetSetting() map[string]interface{} { @@ -70,7 +70,7 @@ func updateUserCache(user User) error { return common.RedisHSetObj( getUserCacheKey(user.Id), user.ToBaseUser(), - time.Duration(constant.RedisKeyCacheSeconds())*time.Second, + time.Duration(common.RedisKeyCacheSeconds())*time.Second, ) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 424fd3df..711284f1 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -9,8 +9,7 @@ import ( "mime/multipart" "net/http" "net/textproto" - "one-api/common" - constant2 "one-api/constant" + "one-api/constant" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/ai360" @@ -21,7 +20,7 @@ import ( "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" "one-api/relay/common_handler" - "one-api/relay/constant" + relayconstant "one-api/relay/constant" "one-api/service" "path/filepath" "strings" @@ -54,7 +53,7 @@ func (a 
*Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType // initialize ThinkingContentInfo when thinking_to_content is enabled - if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content { + if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content { info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, @@ -67,7 +66,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayFormat == relaycommon.RelayFormatClaude { return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil } - if info.RelayMode == constant.RelayModeRealtime { + if info.RelayMode == relayconstant.RelayModeRealtime { if strings.HasPrefix(info.BaseUrl, "https://") { baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") baseUrl = "wss://" + baseUrl @@ -79,10 +78,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } switch info.ChannelType { - case common.ChannelTypeAzure: + case constant.ChannelTypeAzure: apiVersion := info.ApiVersion if apiVersion == "" { - apiVersion = constant2.AzureDefaultAPIVersion + apiVersion = constant.AzureDefaultAPIVersion } // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api requestURL := strings.Split(info.RequestURLPath, "?")[0] @@ -90,25 +89,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { task := strings.TrimPrefix(requestURL, "/v1/") // 特殊处理 responses API - if info.RelayMode == constant.RelayModeResponses { + if info.RelayMode == relayconstant.RelayModeResponses { requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview") return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil } model_ := info.UpstreamModelName // 2025年5月10日后创建的渠道不移除. 
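// (In English: channels created after the AzureNoRemoveDotTime cutoff keep the "."
// in the model name; older channels still have it stripped, because legacy Azure
// deployment names did not accept dots.) A minimal sketch of the behaviour the lines
// below preserve, assuming both values are Unix timestamps as elsewhere in this diff:
//
//	deployment := info.UpstreamModelName // e.g. "gpt-3.5-turbo"
//	if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
//		deployment = strings.ReplaceAll(deployment, ".", "") // legacy channels get "gpt-35-turbo"
//	}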
- if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime { + if info.ChannelCreateTime < constant.AzureNoRemoveDotTime { model_ = strings.Replace(model_, ".", "", -1) } // https://github.com/songquanpeng/one-api/issues/67 requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - if info.RelayMode == constant.RelayModeRealtime { + if info.RelayMode == relayconstant.RelayModeRealtime { requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion) } return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil - case common.ChannelTypeMiniMax: + case constant.ChannelTypeMiniMax: return minimax.GetRequestURL(info) - case common.ChannelTypeCustom: + case constant.ChannelTypeCustom: url := info.BaseUrl url = strings.Replace(url, "{model}", info.UpstreamModelName, -1) return url, nil @@ -119,14 +118,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, header) - if info.ChannelType == common.ChannelTypeAzure { + if info.ChannelType == constant.ChannelTypeAzure { header.Set("api-key", info.ApiKey) return nil } - if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization { + if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization { header.Set("OpenAI-Organization", info.Organization) } - if info.RelayMode == constant.RelayModeRealtime { + if info.RelayMode == relayconstant.RelayModeRealtime { swp := c.Request.Header.Get("Sec-WebSocket-Protocol") if swp != "" { items := []string{ @@ -145,7 +144,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info * } else { header.Set("Authorization", "Bearer "+info.ApiKey) } - if info.ChannelType == common.ChannelTypeOpenRouter { + if info.ChannelType == constant.ChannelTypeOpenRouter { header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api") header.Set("X-Title", "New API") } @@ -156,10 +155,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure { + if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure { request.StreamOptions = nil } - if info.ChannelType == common.ChannelTypeOpenRouter { + if info.ChannelType == constant.ChannelTypeOpenRouter { if len(request.Usage) == 0 { request.Usage = json.RawMessage(`{"include":true}`) } @@ -205,7 +204,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { a.ResponseFormat = request.ResponseFormat - if info.RelayMode == constant.RelayModeAudioSpeech { + if info.RelayMode == relayconstant.RelayModeAudioSpeech { jsonData, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("error marshalling object: %w", err) @@ -254,7 +253,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { switch info.RelayMode { - case constant.RelayModeImagesEdits: + case relayconstant.RelayModeImagesEdits: var requestBody bytes.Buffer writer := 
 		writer := multipart.NewWriter(&requestBody)
@@ -411,11 +410,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
 }
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
-	if info.RelayMode == constant.RelayModeAudioTranscription ||
-		info.RelayMode == constant.RelayModeAudioTranslation ||
-		info.RelayMode == constant.RelayModeImagesEdits {
+	if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
+		info.RelayMode == relayconstant.RelayModeAudioTranslation ||
+		info.RelayMode == relayconstant.RelayModeImagesEdits {
 		return channel.DoFormRequest(a, c, info, requestBody)
-	} else if info.RelayMode == constant.RelayModeRealtime {
+	} else if info.RelayMode == relayconstant.RelayModeRealtime {
 		return channel.DoWssRequest(a, c, info, requestBody)
 	} else {
 		return channel.DoApiRequest(a, c, info, requestBody)
@@ -424,19 +423,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	switch info.RelayMode {
-	case constant.RelayModeRealtime:
+	case relayconstant.RelayModeRealtime:
 		err, usage = OpenaiRealtimeHandler(c, info)
-	case constant.RelayModeAudioSpeech:
+	case relayconstant.RelayModeAudioSpeech:
 		err, usage = OpenaiTTSHandler(c, resp, info)
-	case constant.RelayModeAudioTranslation:
+	case relayconstant.RelayModeAudioTranslation:
 		fallthrough
-	case constant.RelayModeAudioTranscription:
+	case relayconstant.RelayModeAudioTranscription:
 		err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
-	case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
+	case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
 		err, usage = OpenaiHandlerWithUsage(c, resp, info)
-	case constant.RelayModeRerank:
+	case relayconstant.RelayModeRerank:
 		err, usage = common_handler.RerankHandler(c, info, resp)
-	case constant.RelayModeResponses:
+	case relayconstant.RelayModeResponses:
 		if info.IsStream {
 			err, usage = OaiResponsesStreamHandler(c, resp, info)
 		} else {
@@ -454,17 +453,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 func (a *Adaptor) GetModelList() []string {
 	switch a.ChannelType {
-	case common.ChannelType360:
+	case constant.ChannelType360:
 		return ai360.ModelList
-	case common.ChannelTypeMoonshot:
+	case constant.ChannelTypeMoonshot:
 		return moonshot.ModelList
-	case common.ChannelTypeLingYiWanWu:
+	case constant.ChannelTypeLingYiWanWu:
 		return lingyiwanwu.ModelList
-	case common.ChannelTypeMiniMax:
+	case constant.ChannelTypeMiniMax:
 		return minimax.ModelList
-	case common.ChannelTypeXinference:
+	case constant.ChannelTypeXinference:
 		return xinference.ModelList
-	case common.ChannelTypeOpenRouter:
+	case constant.ChannelTypeOpenRouter:
 		return openrouter.ModelList
 	default:
 		return ModelList
@@ -473,17 +472,17 @@ func (a *Adaptor) GetModelList() []string {
 func (a *Adaptor) GetChannelName() string {
 	switch a.ChannelType {
-	case common.ChannelType360:
+	case constant.ChannelType360:
 		return ai360.ChannelName
-	case common.ChannelTypeMoonshot:
+	case constant.ChannelTypeMoonshot:
 		return moonshot.ChannelName
-	case common.ChannelTypeLingYiWanWu:
+	case constant.ChannelTypeLingYiWanWu:
 		return lingyiwanwu.ChannelName
-	case common.ChannelTypeMiniMax:
+	case constant.ChannelTypeMiniMax:
 		return minimax.ChannelName
-	case common.ChannelTypeXinference:
+	case constant.ChannelTypeXinference:
 		return xinference.ChannelName
-	case common.ChannelTypeOpenRouter:
+	case constant.ChannelTypeOpenRouter:
 		return openrouter.ChannelName
 	default:
 		return ChannelName
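
For readers of the Azure branch of GetRequestURL above, here is a standalone sketch of the URL shape it produces. This is not the code in this diff: the base URL and model name are made up, and api-version handling plus GetFullRequestURL are omitted.

package main

import (
	"fmt"
	"strings"
)

// buildAzureURL mirrors, in simplified form, the Azure branch above: strip the
// query string, drop the "/v1/" prefix to get the task, and (for channels
// created before AzureNoRemoveDotTime) remove dots from the deployment name.
func buildAzureURL(baseURL, requestURLPath, model string, removeDot bool) string {
	requestURL := strings.Split(requestURLPath, "?")[0]
	task := strings.TrimPrefix(requestURL, "/v1/")
	if removeDot {
		model = strings.Replace(model, ".", "", -1)
	}
	return fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model, task)
}

func main() {
	// Hypothetical values; prints:
	// https://example.openai.azure.com/openai/deployments/gpt-41/chat/completions
	fmt.Println(buildAzureURL("https://example.openai.azure.com", "/v1/chat/completions?foo=bar", "gpt-4.1", true))
}
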
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 8a7d55d5..7c283bd0 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -168,7 +168,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	} else {
-		if info.ChannelType == common.ChannelTypeDeepSeek {
+		if info.ChannelType == constant.ChannelTypeDeepSeek {
 			if usage.PromptCacheHitTokens != 0 {
 				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
 			}
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 5fd94788..37161c16 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -113,17 +113,17 @@ type RelayInfo struct {
 // channel types that support stream options
 var streamSupportedChannels = map[int]bool{
-	common.ChannelTypeOpenAI:     true,
-	common.ChannelTypeAnthropic:  true,
-	common.ChannelTypeAws:        true,
-	common.ChannelTypeGemini:     true,
-	common.ChannelCloudflare:     true,
-	common.ChannelTypeAzure:      true,
-	common.ChannelTypeVolcEngine: true,
-	common.ChannelTypeOllama:     true,
-	common.ChannelTypeXai:        true,
-	common.ChannelTypeDeepSeek:   true,
-	common.ChannelTypeBaiduV2:    true,
+	constant.ChannelTypeOpenAI:     true,
+	constant.ChannelTypeAnthropic:  true,
+	constant.ChannelTypeAws:        true,
+	constant.ChannelTypeGemini:     true,
+	constant.ChannelCloudflare:     true,
+	constant.ChannelTypeAzure:      true,
+	constant.ChannelTypeVolcEngine: true,
+	constant.ChannelTypeOllama:     true,
+	constant.ChannelTypeXai:        true,
+	constant.ChannelTypeDeepSeek:   true,
+	constant.ChannelTypeBaiduV2:    true,
 }
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -211,40 +211,40 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
 }
 func GenRelayInfo(c *gin.Context) *RelayInfo {
-	channelType := c.GetInt("channel_type")
-	channelId := c.GetInt("channel_id")
-	channelSetting := c.GetStringMap("channel_setting")
-	paramOverride := c.GetStringMap("param_override")
+	channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
+	channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
+	channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
+	paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
-	tokenId := c.GetInt("token_id")
-	tokenKey := c.GetString("token_key")
-	userId := c.GetInt("id")
-	tokenUnlimited := c.GetBool("token_unlimited_quota")
-	startTime := c.GetTime(constant.ContextKeyRequestStartTime)
+	tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
+	tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
+	userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
+	tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
+	startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
 	// firstResponseTime = time.Now() - 1 second
-	apiType, _ := relayconstant.ChannelType2APIType(channelType)
+	apiType, _ := common.ChannelType2APIType(channelType)
 	info := &RelayInfo{
-		UserQuota:   c.GetInt(constant.ContextKeyUserQuota),
-		UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
-		UserEmail:   c.GetString(constant.ContextKeyUserEmail),
+		UserQuota:   common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
+		UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
+		UserEmail:   common.GetContextKeyString(c, constant.ContextKeyUserEmail),
 		isFirstResponse: true,
 		RelayMode:       relayconstant.Path2RelayMode(c.Request.URL.Path),
-		BaseUrl:         c.GetString("base_url"),
+		BaseUrl:         common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
 		RequestURLPath:  c.Request.URL.String(),
 		ChannelType:     channelType,
 		ChannelId:       channelId,
 		TokenId:         tokenId,
 		TokenKey:        tokenKey,
 		UserId:          userId,
-		UsingGroup:      c.GetString(constant.ContextKeyUsingGroup),
-		UserGroup:       c.GetString(constant.ContextKeyUserGroup),
+		UsingGroup:      common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
+		UserGroup:       common.GetContextKeyString(c, constant.ContextKeyUserGroup),
 		TokenUnlimited:    tokenUnlimited,
 		StartTime:         startTime,
 		FirstResponseTime: startTime.Add(-time.Second),
-		OriginModelName:   c.GetString("original_model"),
-		UpstreamModelName: c.GetString("original_model"),
+		OriginModelName:   common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+		UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
 		//RecodeModelName: c.GetString("original_model"),
 		IsModelMapped: false,
 		ApiType:       apiType,
@@ -266,12 +266,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		info.RequestURLPath = "/v1" + info.RequestURLPath
 	}
 	if info.BaseUrl == "" {
-		info.BaseUrl = common.ChannelBaseURLs[channelType]
+		info.BaseUrl = constant.ChannelBaseURLs[channelType]
 	}
-	if info.ChannelType == common.ChannelTypeAzure {
+	if info.ChannelType == constant.ChannelTypeAzure {
 		info.ApiVersion = GetAPIVersion(c)
 	}
-	if info.ChannelType == common.ChannelTypeVertexAi {
+	if info.ChannelType == constant.ChannelTypeVertexAi {
 		info.ApiVersion = c.GetString("region")
 	}
 	if streamSupportedChannels[info.ChannelType] {
diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go
index 7a4f44bb..29086585 100644
--- a/relay/common/relay_utils.go
+++ b/relay/common/relay_utils.go
@@ -6,7 +6,7 @@ import (
 	_ "image/gif"
 	_ "image/jpeg"
 	_ "image/png"
-	"one-api/common"
+	"one-api/constant"
 	"strings"
 )
@@ -15,9 +15,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
 	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
 		switch channelType {
-		case common.ChannelTypeOpenAI:
+		case constant.ChannelTypeOpenAI:
 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
-		case common.ChannelTypeAzure:
+		case constant.ChannelTypeAzure:
 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
 		}
 	}
diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go
index d7033846..0df219e3 100644
--- a/relay/common_handler/rerank.go
+++ b/relay/common_handler/rerank.go
@@ -5,6 +5,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel/xinference"
 	relaycommon "one-api/relay/common"
@@ -21,7 +22,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		println("reranker response body: ", string(responseBody))
 	}
 	var jinaResp dto.RerankResponse
-	if info.ChannelType == common.ChannelTypeXinference {
+	if info.ChannelType == constant.ChannelTypeXinference {
 		var xinRerankResponse xinference.XinRerankResponse
 		err = common.UnmarshalJson(responseBody, &xinRerankResponse)
 		if err != nil {
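
GenRelayInfo above swaps raw gin getters such as c.GetInt("channel_type") for typed accessors like common.GetContextKeyInt(c, constant.ContextKeyChannelType). Those helpers and the ContextKey* constants are defined outside this diff; the sketch below only illustrates the shape implied by the call sites, and the key type, names, and return types are assumptions, not the project's actual code.

package common

import (
	"time"

	"github.com/gin-gonic/gin"
)

// Sketch only: typed accessors over gin's per-request context, keyed by the
// constant.ContextKey* values used in GenRelayInfo. The real helpers may differ.
func GetContextKeyInt(c *gin.Context, key string) int         { return c.GetInt(key) }
func GetContextKeyString(c *gin.Context, key string) string   { return c.GetString(key) }
func GetContextKeyBool(c *gin.Context, key string) bool       { return c.GetBool(key) }
func GetContextKeyTime(c *gin.Context, key string) time.Time  { return c.GetTime(key) }

func GetContextKeyStringMap(c *gin.Context, key string) map[string]interface{} {
	return c.GetStringMap(key)
}
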
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
deleted file mode 100644
index 3f1ecd78..00000000
--- a/relay/constant/api_type.go
+++ /dev/null
@@ -1,106 +0,0 @@
-package constant
-
-import (
-	"one-api/common"
-)
-
-const (
-	APITypeOpenAI = iota
-	APITypeAnthropic
-	APITypePaLM
-	APITypeBaidu
-	APITypeZhipu
-	APITypeAli
-	APITypeXunfei
-	APITypeAIProxyLibrary
-	APITypeTencent
-	APITypeGemini
-	APITypeZhipuV4
-	APITypeOllama
-	APITypePerplexity
-	APITypeAws
-	APITypeCohere
-	APITypeDify
-	APITypeJina
-	APITypeCloudflare
-	APITypeSiliconFlow
-	APITypeVertexAi
-	APITypeMistral
-	APITypeDeepSeek
-	APITypeMokaAI
-	APITypeVolcEngine
-	APITypeBaiduV2
-	APITypeOpenRouter
-	APITypeXinference
-	APITypeXai
-	APITypeCoze
-	APITypeDummy // this one is only for count, do not add any channel after this
-)
-
-func ChannelType2APIType(channelType int) (int, bool) {
-	apiType := -1
-	switch channelType {
-	case common.ChannelTypeOpenAI:
-		apiType = APITypeOpenAI
-	case common.ChannelTypeAnthropic:
-		apiType = APITypeAnthropic
-	case common.ChannelTypeBaidu:
-		apiType = APITypeBaidu
-	case common.ChannelTypePaLM:
-		apiType = APITypePaLM
-	case common.ChannelTypeZhipu:
-		apiType = APITypeZhipu
-	case common.ChannelTypeAli:
-		apiType = APITypeAli
-	case common.ChannelTypeXunfei:
-		apiType = APITypeXunfei
-	case common.ChannelTypeAIProxyLibrary:
-		apiType = APITypeAIProxyLibrary
-	case common.ChannelTypeTencent:
-		apiType = APITypeTencent
-	case common.ChannelTypeGemini:
-		apiType = APITypeGemini
-	case common.ChannelTypeZhipu_v4:
-		apiType = APITypeZhipuV4
-	case common.ChannelTypeOllama:
-		apiType = APITypeOllama
-	case common.ChannelTypePerplexity:
-		apiType = APITypePerplexity
-	case common.ChannelTypeAws:
-		apiType = APITypeAws
-	case common.ChannelTypeCohere:
-		apiType = APITypeCohere
-	case common.ChannelTypeDify:
-		apiType = APITypeDify
-	case common.ChannelTypeJina:
-		apiType = APITypeJina
-	case common.ChannelCloudflare:
-		apiType = APITypeCloudflare
-	case common.ChannelTypeSiliconFlow:
-		apiType = APITypeSiliconFlow
-	case common.ChannelTypeVertexAi:
-		apiType = APITypeVertexAi
-	case common.ChannelTypeMistral:
-		apiType = APITypeMistral
-	case common.ChannelTypeDeepSeek:
-		apiType = APITypeDeepSeek
-	case common.ChannelTypeMokaAI:
-		apiType = APITypeMokaAI
-	case common.ChannelTypeVolcEngine:
-		apiType = APITypeVolcEngine
-	case common.ChannelTypeBaiduV2:
-		apiType = APITypeBaiduV2
-	case common.ChannelTypeOpenRouter:
-		apiType = APITypeOpenRouter
-	case common.ChannelTypeXinference:
-		apiType = APITypeXinference
-	case common.ChannelTypeXai:
-		apiType = APITypeXai
-	case common.ChannelTypeCoze:
-		apiType = APITypeCoze
-	}
-	if apiType == -1 {
-		return APITypeOpenAI, false
-	}
-	return apiType, true
-}
diff --git a/relay/image_handler.go b/relay/image_handler.go
index 15a42e79..5decb497 100644
--- a/relay/image_handler.go
+++ b/relay/image_handler.go
@@ -8,6 +8,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
@@ -17,8 +18,6 @@ import (
 	"one-api/setting"
 	"strings"
-	"one-api/relay/constant"
-
 	"github.com/gin-gonic/gin"
 )
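
The deleted relay/constant/api_type.go above is replaced by the same ChannelType2APIType mapping in the common package, keeping the (int, bool) signature; GenRelayInfo simply calls common.ChannelType2APIType and ignores the flag. A sketch of how a caller could use the second return value to notice an unmapped channel type (the package name and logging are illustrative, not code from this change):

package example

import (
	"log"

	"one-api/common"
	"one-api/constant"
)

// resolveAPIType shows the two-value form: ok is false when the channel type
// has no dedicated API type and the mapping falls back to constant.APITypeOpenAI.
func resolveAPIType(channelType int) int {
	apiType, ok := common.ChannelType2APIType(channelType)
	if !ok {
		log.Printf("channel type %d not mapped, defaulting to APITypeOpenAI (%d)", channelType, constant.APITypeOpenAI)
	}
	return apiType
}
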
"one-api/relay/constant" ) func GetAdaptor(apiType int) channel.Adaptor { diff --git a/service/channel.go b/service/channel.go index 746e9a34..d50de78d 100644 --- a/service/channel.go +++ b/service/channel.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" "one-api/setting/operation_setting" @@ -48,7 +49,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b } if err.StatusCode == http.StatusForbidden { switch channelType { - case common.ChannelTypeGemini: + case constant.ChannelTypeGemini: return true } } diff --git a/service/convert.go b/service/convert.go index df7acf0d..c97f8475 100644 --- a/service/convert.go +++ b/service/convert.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" @@ -19,7 +20,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re Stream: claudeRequest.Stream, } - isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter + isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" { if isOpenRouter { diff --git a/service/quota.go b/service/quota.go index c17616a7..bc3ef296 100644 --- a/service/quota.go +++ b/service/quota.go @@ -6,7 +6,7 @@ import ( "log" "math" "one-api/common" - constant2 "one-api/constant" + "one-api/constant" "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" @@ -232,7 +232,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, cacheCreationRatio := priceData.CacheCreationRatio cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens - if relayInfo.ChannelType == common.ChannelTypeOpenRouter { + if relayInfo.ChannelType == constant.ChannelTypeOpenRouter { promptTokens -= cacheTokens if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 { maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData) @@ -447,7 +447,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon gopool.Go(func() { userSetting := relayInfo.UserSetting threshold := common.QuotaRemindThreshold - if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok { + if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok { threshold = int(userCustomThreshold.(float64)) } diff --git a/service/token_counter.go b/service/token_counter.go index 53c6c2fa..302d6c1a 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -101,7 +101,7 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m if !constant.GetMediaToken { return 3 * baseTokens, nil } - if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic { + if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic { return 3 * baseTokens, nil } var config image.Config diff --git a/types/set.go b/types/set.go new file mode 100644 index 00000000..db6b0272 --- /dev/null +++ b/types/set.go @@ -0,0 +1,42 @@ +package types + +type Set[T comparable] struct { + items map[T]struct{} +} + +// NewSet 创建并返回一个新的 Set +func NewSet[T comparable]() *Set[T] { + return &Set[T]{ + items: make(map[T]struct{}), 
diff --git a/types/set.go b/types/set.go
new file mode 100644
index 00000000..db6b0272
--- /dev/null
+++ b/types/set.go
@@ -0,0 +1,42 @@
+package types
+
+type Set[T comparable] struct {
+	items map[T]struct{}
+}
+
+// NewSet creates and returns a new Set
+func NewSet[T comparable]() *Set[T] {
+	return &Set[T]{
+		items: make(map[T]struct{}),
+	}
+}
+
+func (s *Set[T]) Add(item T) {
+	s.items[item] = struct{}{}
+}
+
+// Remove removes an element from the Set
+func (s *Set[T]) Remove(item T) {
+	delete(s.items, item)
+}
+
+// Contains reports whether the Set contains the given element
+func (s *Set[T]) Contains(item T) bool {
+	_, exists := s.items[item]
+	return exists
+}
+
+// Len returns the number of elements in the Set
+func (s *Set[T]) Len() int {
+	return len(s.items)
+}
+
+// Items returns a slice of all elements in the Set
+// Note: because map iteration is unordered, the order of the returned slice is random
+func (s *Set[T]) Items() []T {
+	items := make([]T, 0, s.Len())
+	for item := range s.items {
+		items = append(items, item)
+	}
+	return items
+}
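
The new generic Set above is self-contained; a minimal usage sketch (illustrative only, the call site and model names are made up and not part of this change):

package main

import (
	"fmt"

	"one-api/types"
)

func main() {
	s := types.NewSet[string]()
	s.Add("gpt-4o")
	s.Add("gpt-4o") // duplicates are collapsed
	s.Add("claude-3-5-sonnet")

	fmt.Println(s.Contains("gpt-4o")) // true
	fmt.Println(s.Len())              // 2
	fmt.Println(s.Items())            // order is not guaranteed
}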