Merge pull request #1326 from QuantumNous/refactor_constant
✨ feat: refactor environment variable initialization
This commit is contained in:
71
common/api_type.go
Normal file
71
common/api_type.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "one-api/constant"
|
||||||
|
|
||||||
|
func ChannelType2APIType(channelType int) (int, bool) {
|
||||||
|
apiType := -1
|
||||||
|
switch channelType {
|
||||||
|
case constant.ChannelTypeOpenAI:
|
||||||
|
apiType = constant.APITypeOpenAI
|
||||||
|
case constant.ChannelTypeAnthropic:
|
||||||
|
apiType = constant.APITypeAnthropic
|
||||||
|
case constant.ChannelTypeBaidu:
|
||||||
|
apiType = constant.APITypeBaidu
|
||||||
|
case constant.ChannelTypePaLM:
|
||||||
|
apiType = constant.APITypePaLM
|
||||||
|
case constant.ChannelTypeZhipu:
|
||||||
|
apiType = constant.APITypeZhipu
|
||||||
|
case constant.ChannelTypeAli:
|
||||||
|
apiType = constant.APITypeAli
|
||||||
|
case constant.ChannelTypeXunfei:
|
||||||
|
apiType = constant.APITypeXunfei
|
||||||
|
case constant.ChannelTypeAIProxyLibrary:
|
||||||
|
apiType = constant.APITypeAIProxyLibrary
|
||||||
|
case constant.ChannelTypeTencent:
|
||||||
|
apiType = constant.APITypeTencent
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
apiType = constant.APITypeGemini
|
||||||
|
case constant.ChannelTypeZhipu_v4:
|
||||||
|
apiType = constant.APITypeZhipuV4
|
||||||
|
case constant.ChannelTypeOllama:
|
||||||
|
apiType = constant.APITypeOllama
|
||||||
|
case constant.ChannelTypePerplexity:
|
||||||
|
apiType = constant.APITypePerplexity
|
||||||
|
case constant.ChannelTypeAws:
|
||||||
|
apiType = constant.APITypeAws
|
||||||
|
case constant.ChannelTypeCohere:
|
||||||
|
apiType = constant.APITypeCohere
|
||||||
|
case constant.ChannelTypeDify:
|
||||||
|
apiType = constant.APITypeDify
|
||||||
|
case constant.ChannelTypeJina:
|
||||||
|
apiType = constant.APITypeJina
|
||||||
|
case constant.ChannelCloudflare:
|
||||||
|
apiType = constant.APITypeCloudflare
|
||||||
|
case constant.ChannelTypeSiliconFlow:
|
||||||
|
apiType = constant.APITypeSiliconFlow
|
||||||
|
case constant.ChannelTypeVertexAi:
|
||||||
|
apiType = constant.APITypeVertexAi
|
||||||
|
case constant.ChannelTypeMistral:
|
||||||
|
apiType = constant.APITypeMistral
|
||||||
|
case constant.ChannelTypeDeepSeek:
|
||||||
|
apiType = constant.APITypeDeepSeek
|
||||||
|
case constant.ChannelTypeMokaAI:
|
||||||
|
apiType = constant.APITypeMokaAI
|
||||||
|
case constant.ChannelTypeVolcEngine:
|
||||||
|
apiType = constant.APITypeVolcEngine
|
||||||
|
case constant.ChannelTypeBaiduV2:
|
||||||
|
apiType = constant.APITypeBaiduV2
|
||||||
|
case constant.ChannelTypeOpenRouter:
|
||||||
|
apiType = constant.APITypeOpenRouter
|
||||||
|
case constant.ChannelTypeXinference:
|
||||||
|
apiType = constant.APITypeXinference
|
||||||
|
case constant.ChannelTypeXai:
|
||||||
|
apiType = constant.APITypeXai
|
||||||
|
case constant.ChannelTypeCoze:
|
||||||
|
apiType = constant.APITypeCoze
|
||||||
|
}
|
||||||
|
if apiType == -1 {
|
||||||
|
return constant.APITypeOpenAI, false
|
||||||
|
}
|
||||||
|
return apiType, true
|
||||||
|
}
|
||||||
@@ -193,111 +193,3 @@ const (
|
|||||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||||
ChannelStatusAutoDisabled = 3
|
ChannelStatusAutoDisabled = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
ChannelTypeUnknown = 0
|
|
||||||
ChannelTypeOpenAI = 1
|
|
||||||
ChannelTypeMidjourney = 2
|
|
||||||
ChannelTypeAzure = 3
|
|
||||||
ChannelTypeOllama = 4
|
|
||||||
ChannelTypeMidjourneyPlus = 5
|
|
||||||
ChannelTypeOpenAIMax = 6
|
|
||||||
ChannelTypeOhMyGPT = 7
|
|
||||||
ChannelTypeCustom = 8
|
|
||||||
ChannelTypeAILS = 9
|
|
||||||
ChannelTypeAIProxy = 10
|
|
||||||
ChannelTypePaLM = 11
|
|
||||||
ChannelTypeAPI2GPT = 12
|
|
||||||
ChannelTypeAIGC2D = 13
|
|
||||||
ChannelTypeAnthropic = 14
|
|
||||||
ChannelTypeBaidu = 15
|
|
||||||
ChannelTypeZhipu = 16
|
|
||||||
ChannelTypeAli = 17
|
|
||||||
ChannelTypeXunfei = 18
|
|
||||||
ChannelType360 = 19
|
|
||||||
ChannelTypeOpenRouter = 20
|
|
||||||
ChannelTypeAIProxyLibrary = 21
|
|
||||||
ChannelTypeFastGPT = 22
|
|
||||||
ChannelTypeTencent = 23
|
|
||||||
ChannelTypeGemini = 24
|
|
||||||
ChannelTypeMoonshot = 25
|
|
||||||
ChannelTypeZhipu_v4 = 26
|
|
||||||
ChannelTypePerplexity = 27
|
|
||||||
ChannelTypeLingYiWanWu = 31
|
|
||||||
ChannelTypeAws = 33
|
|
||||||
ChannelTypeCohere = 34
|
|
||||||
ChannelTypeMiniMax = 35
|
|
||||||
ChannelTypeSunoAPI = 36
|
|
||||||
ChannelTypeDify = 37
|
|
||||||
ChannelTypeJina = 38
|
|
||||||
ChannelCloudflare = 39
|
|
||||||
ChannelTypeSiliconFlow = 40
|
|
||||||
ChannelTypeVertexAi = 41
|
|
||||||
ChannelTypeMistral = 42
|
|
||||||
ChannelTypeDeepSeek = 43
|
|
||||||
ChannelTypeMokaAI = 44
|
|
||||||
ChannelTypeVolcEngine = 45
|
|
||||||
ChannelTypeBaiduV2 = 46
|
|
||||||
ChannelTypeXinference = 47
|
|
||||||
ChannelTypeXai = 48
|
|
||||||
ChannelTypeCoze = 49
|
|
||||||
ChannelTypeKling = 50
|
|
||||||
ChannelTypeJimeng = 51
|
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
|
||||||
"", // 0
|
|
||||||
"https://api.openai.com", // 1
|
|
||||||
"https://oa.api2d.net", // 2
|
|
||||||
"", // 3
|
|
||||||
"http://localhost:11434", // 4
|
|
||||||
"https://api.openai-sb.com", // 5
|
|
||||||
"https://api.openaimax.com", // 6
|
|
||||||
"https://api.ohmygpt.com", // 7
|
|
||||||
"", // 8
|
|
||||||
"https://api.caipacity.com", // 9
|
|
||||||
"https://api.aiproxy.io", // 10
|
|
||||||
"", // 11
|
|
||||||
"https://api.api2gpt.com", // 12
|
|
||||||
"https://api.aigc2d.com", // 13
|
|
||||||
"https://api.anthropic.com", // 14
|
|
||||||
"https://aip.baidubce.com", // 15
|
|
||||||
"https://open.bigmodel.cn", // 16
|
|
||||||
"https://dashscope.aliyuncs.com", // 17
|
|
||||||
"", // 18
|
|
||||||
"https://api.360.cn", // 19
|
|
||||||
"https://openrouter.ai/api", // 20
|
|
||||||
"https://api.aiproxy.io", // 21
|
|
||||||
"https://fastgpt.run/api/openapi", // 22
|
|
||||||
"https://hunyuan.tencentcloudapi.com", //23
|
|
||||||
"https://generativelanguage.googleapis.com", //24
|
|
||||||
"https://api.moonshot.cn", //25
|
|
||||||
"https://open.bigmodel.cn", //26
|
|
||||||
"https://api.perplexity.ai", //27
|
|
||||||
"", //28
|
|
||||||
"", //29
|
|
||||||
"", //30
|
|
||||||
"https://api.lingyiwanwu.com", //31
|
|
||||||
"", //32
|
|
||||||
"", //33
|
|
||||||
"https://api.cohere.ai", //34
|
|
||||||
"https://api.minimax.chat", //35
|
|
||||||
"", //36
|
|
||||||
"https://api.dify.ai", //37
|
|
||||||
"https://api.jina.ai", //38
|
|
||||||
"https://api.cloudflare.com", //39
|
|
||||||
"https://api.siliconflow.cn", //40
|
|
||||||
"", //41
|
|
||||||
"https://api.mistral.ai", //42
|
|
||||||
"https://api.deepseek.com", //43
|
|
||||||
"https://api.moka.ai", //44
|
|
||||||
"https://ark.cn-beijing.volces.com", //45
|
|
||||||
"https://qianfan.baidubce.com", //46
|
|
||||||
"", //47
|
|
||||||
"https://api.x.ai", //48
|
|
||||||
"https://api.coze.cn", //49
|
|
||||||
"https://api.klingai.com", //50
|
|
||||||
"https://visual.volcengineapi.com", //51
|
|
||||||
}
|
|
||||||
|
|||||||
29
common/endpoint_type.go
Normal file
29
common/endpoint_type.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "one-api/constant"
|
||||||
|
|
||||||
|
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
|
||||||
|
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
|
||||||
|
var endpointTypes []constant.EndpointType
|
||||||
|
switch channelType {
|
||||||
|
case constant.ChannelTypeJina:
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
|
||||||
|
case constant.ChannelTypeAws:
|
||||||
|
fallthrough
|
||||||
|
case constant.ChannelTypeAnthropic:
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
|
||||||
|
case constant.ChannelTypeVertexAi:
|
||||||
|
fallthrough
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||||
|
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||||
|
default:
|
||||||
|
if IsOpenAIResponseOnlyModel(modelName) {
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||||
|
} else {
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return endpointTypes
|
||||||
|
}
|
||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
|
"one-api/constant"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const KeyRequestBody = "key_request_body"
|
const KeyRequestBody = "key_request_body"
|
||||||
@@ -42,3 +44,35 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
||||||
|
c.Set(string(key), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
||||||
|
return c.Get(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
||||||
|
return c.GetString(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
||||||
|
return c.GetInt(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
||||||
|
return c.GetBool(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
||||||
|
return c.GetStringSlice(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
||||||
|
return c.GetStringMap(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
||||||
|
return c.GetTime(string(key))
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"one-api/constant"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -24,7 +25,7 @@ func printHelp() {
|
|||||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitCommonEnv() {
|
func InitEnv() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *PrintVersion {
|
if *PrintVersion {
|
||||||
@@ -95,4 +96,25 @@ func InitCommonEnv() {
|
|||||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||||
|
|
||||||
|
initConstantEnv()
|
||||||
|
}
|
||||||
|
|
||||||
|
func initConstantEnv() {
|
||||||
|
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||||
|
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||||
|
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||||
|
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||||
|
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||||
|
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||||
|
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||||
|
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||||
|
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||||
|
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||||
|
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||||
|
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||||
|
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||||
|
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||||
|
// 是否启用错误日志
|
||||||
|
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||||
}
|
}
|
||||||
|
|||||||
21
common/model.go
Normal file
21
common/model.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
|
||||||
|
OpenAIResponseOnlyModels = []string{
|
||||||
|
"o3-pro",
|
||||||
|
"o3-deep-research",
|
||||||
|
"o4-mini-deep-research",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||||
|
for _, m := range OpenAIResponseOnlyModels {
|
||||||
|
if strings.Contains(m, modelName) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -16,6 +16,10 @@ import (
|
|||||||
var RDB *redis.Client
|
var RDB *redis.Client
|
||||||
var RedisEnabled = true
|
var RedisEnabled = true
|
||||||
|
|
||||||
|
func RedisKeyCacheSeconds() int {
|
||||||
|
return SyncFrequency
|
||||||
|
}
|
||||||
|
|
||||||
// InitRedisClient This function is called after init()
|
// InitRedisClient This function is called after init()
|
||||||
func InitRedisClient() (err error) {
|
func InitRedisClient() (err error) {
|
||||||
if os.Getenv("REDIS_CONN_STRING") == "" {
|
if os.Getenv("REDIS_CONN_STRING") == "" {
|
||||||
|
|||||||
26
constant/README.md
Normal file
26
constant/README.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# constant 包 (`/constant`)
|
||||||
|
|
||||||
|
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
|
||||||
|
|
||||||
|
## 当前文件
|
||||||
|
|
||||||
|
| 文件 | 说明 |
|
||||||
|
|----------------------|---------------------------------------------------------------------|
|
||||||
|
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
|
||||||
|
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
|
||||||
|
| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 |
|
||||||
|
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
|
||||||
|
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
|
||||||
|
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
|
||||||
|
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
|
||||||
|
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
|
||||||
|
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
|
||||||
|
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
|
||||||
|
|
||||||
|
## 使用约定
|
||||||
|
|
||||||
|
1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
|
||||||
|
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
|
||||||
|
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
|
||||||
|
|
||||||
|
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
|
||||||
34
constant/api_type.go
Normal file
34
constant/api_type.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
const (
|
||||||
|
APITypeOpenAI = iota
|
||||||
|
APITypeAnthropic
|
||||||
|
APITypePaLM
|
||||||
|
APITypeBaidu
|
||||||
|
APITypeZhipu
|
||||||
|
APITypeAli
|
||||||
|
APITypeXunfei
|
||||||
|
APITypeAIProxyLibrary
|
||||||
|
APITypeTencent
|
||||||
|
APITypeGemini
|
||||||
|
APITypeZhipuV4
|
||||||
|
APITypeOllama
|
||||||
|
APITypePerplexity
|
||||||
|
APITypeAws
|
||||||
|
APITypeCohere
|
||||||
|
APITypeDify
|
||||||
|
APITypeJina
|
||||||
|
APITypeCloudflare
|
||||||
|
APITypeSiliconFlow
|
||||||
|
APITypeVertexAi
|
||||||
|
APITypeMistral
|
||||||
|
APITypeDeepSeek
|
||||||
|
APITypeMokaAI
|
||||||
|
APITypeVolcEngine
|
||||||
|
APITypeBaiduV2
|
||||||
|
APITypeOpenRouter
|
||||||
|
APITypeXinference
|
||||||
|
APITypeXai
|
||||||
|
APITypeCoze
|
||||||
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
)
|
||||||
@@ -1,12 +1,5 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
import "one-api/common"
|
|
||||||
|
|
||||||
// 使用函数来避免初始化顺序带来的赋值问题
|
|
||||||
func RedisKeyCacheSeconds() int {
|
|
||||||
return common.SyncFrequency
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache keys
|
// Cache keys
|
||||||
const (
|
const (
|
||||||
UserGroupKeyFmt = "user_group:%d"
|
UserGroupKeyFmt = "user_group:%d"
|
||||||
|
|||||||
109
constant/channel.go
Normal file
109
constant/channel.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChannelTypeUnknown = 0
|
||||||
|
ChannelTypeOpenAI = 1
|
||||||
|
ChannelTypeMidjourney = 2
|
||||||
|
ChannelTypeAzure = 3
|
||||||
|
ChannelTypeOllama = 4
|
||||||
|
ChannelTypeMidjourneyPlus = 5
|
||||||
|
ChannelTypeOpenAIMax = 6
|
||||||
|
ChannelTypeOhMyGPT = 7
|
||||||
|
ChannelTypeCustom = 8
|
||||||
|
ChannelTypeAILS = 9
|
||||||
|
ChannelTypeAIProxy = 10
|
||||||
|
ChannelTypePaLM = 11
|
||||||
|
ChannelTypeAPI2GPT = 12
|
||||||
|
ChannelTypeAIGC2D = 13
|
||||||
|
ChannelTypeAnthropic = 14
|
||||||
|
ChannelTypeBaidu = 15
|
||||||
|
ChannelTypeZhipu = 16
|
||||||
|
ChannelTypeAli = 17
|
||||||
|
ChannelTypeXunfei = 18
|
||||||
|
ChannelType360 = 19
|
||||||
|
ChannelTypeOpenRouter = 20
|
||||||
|
ChannelTypeAIProxyLibrary = 21
|
||||||
|
ChannelTypeFastGPT = 22
|
||||||
|
ChannelTypeTencent = 23
|
||||||
|
ChannelTypeGemini = 24
|
||||||
|
ChannelTypeMoonshot = 25
|
||||||
|
ChannelTypeZhipu_v4 = 26
|
||||||
|
ChannelTypePerplexity = 27
|
||||||
|
ChannelTypeLingYiWanWu = 31
|
||||||
|
ChannelTypeAws = 33
|
||||||
|
ChannelTypeCohere = 34
|
||||||
|
ChannelTypeMiniMax = 35
|
||||||
|
ChannelTypeSunoAPI = 36
|
||||||
|
ChannelTypeDify = 37
|
||||||
|
ChannelTypeJina = 38
|
||||||
|
ChannelCloudflare = 39
|
||||||
|
ChannelTypeSiliconFlow = 40
|
||||||
|
ChannelTypeVertexAi = 41
|
||||||
|
ChannelTypeMistral = 42
|
||||||
|
ChannelTypeDeepSeek = 43
|
||||||
|
ChannelTypeMokaAI = 44
|
||||||
|
ChannelTypeVolcEngine = 45
|
||||||
|
ChannelTypeBaiduV2 = 46
|
||||||
|
ChannelTypeXinference = 47
|
||||||
|
ChannelTypeXai = 48
|
||||||
|
ChannelTypeCoze = 49
|
||||||
|
ChannelTypeKling = 50
|
||||||
|
ChannelTypeJimeng = 51
|
||||||
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
var ChannelBaseURLs = []string{
|
||||||
|
"", // 0
|
||||||
|
"https://api.openai.com", // 1
|
||||||
|
"https://oa.api2d.net", // 2
|
||||||
|
"", // 3
|
||||||
|
"http://localhost:11434", // 4
|
||||||
|
"https://api.openai-sb.com", // 5
|
||||||
|
"https://api.openaimax.com", // 6
|
||||||
|
"https://api.ohmygpt.com", // 7
|
||||||
|
"", // 8
|
||||||
|
"https://api.caipacity.com", // 9
|
||||||
|
"https://api.aiproxy.io", // 10
|
||||||
|
"", // 11
|
||||||
|
"https://api.api2gpt.com", // 12
|
||||||
|
"https://api.aigc2d.com", // 13
|
||||||
|
"https://api.anthropic.com", // 14
|
||||||
|
"https://aip.baidubce.com", // 15
|
||||||
|
"https://open.bigmodel.cn", // 16
|
||||||
|
"https://dashscope.aliyuncs.com", // 17
|
||||||
|
"", // 18
|
||||||
|
"https://api.360.cn", // 19
|
||||||
|
"https://openrouter.ai/api", // 20
|
||||||
|
"https://api.aiproxy.io", // 21
|
||||||
|
"https://fastgpt.run/api/openapi", // 22
|
||||||
|
"https://hunyuan.tencentcloudapi.com", //23
|
||||||
|
"https://generativelanguage.googleapis.com", //24
|
||||||
|
"https://api.moonshot.cn", //25
|
||||||
|
"https://open.bigmodel.cn", //26
|
||||||
|
"https://api.perplexity.ai", //27
|
||||||
|
"", //28
|
||||||
|
"", //29
|
||||||
|
"", //30
|
||||||
|
"https://api.lingyiwanwu.com", //31
|
||||||
|
"", //32
|
||||||
|
"", //33
|
||||||
|
"https://api.cohere.ai", //34
|
||||||
|
"https://api.minimax.chat", //35
|
||||||
|
"", //36
|
||||||
|
"https://api.dify.ai", //37
|
||||||
|
"https://api.jina.ai", //38
|
||||||
|
"https://api.cloudflare.com", //39
|
||||||
|
"https://api.siliconflow.cn", //40
|
||||||
|
"", //41
|
||||||
|
"https://api.mistral.ai", //42
|
||||||
|
"https://api.deepseek.com", //43
|
||||||
|
"https://api.moka.ai", //44
|
||||||
|
"https://ark.cn-beijing.volces.com", //45
|
||||||
|
"https://qianfan.baidubce.com", //46
|
||||||
|
"", //47
|
||||||
|
"https://api.x.ai", //48
|
||||||
|
"https://api.coze.cn", //49
|
||||||
|
"https://api.klingai.com", //50
|
||||||
|
"https://visual.volcengineapi.com", //51
|
||||||
|
}
|
||||||
@@ -1,11 +1,35 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
|
type ContextKey string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ContextKeyRequestStartTime = "request_start_time"
|
ContextKeyOriginalModel ContextKey = "original_model"
|
||||||
ContextKeyUserSetting = "user_setting"
|
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||||
ContextKeyUserQuota = "user_quota"
|
|
||||||
ContextKeyUserStatus = "user_status"
|
/* token related keys */
|
||||||
ContextKeyUserEmail = "user_email"
|
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
|
||||||
ContextKeyUserGroup = "user_group"
|
ContextKeyTokenKey ContextKey = "token_key"
|
||||||
ContextKeyUsingGroup = "group"
|
ContextKeyTokenId ContextKey = "token_id"
|
||||||
|
ContextKeyTokenGroup ContextKey = "token_group"
|
||||||
|
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
||||||
|
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||||
|
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||||
|
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||||
|
|
||||||
|
/* channel related keys */
|
||||||
|
ContextKeyBaseUrl ContextKey = "base_url"
|
||||||
|
ContextKeyChannelType ContextKey = "channel_type"
|
||||||
|
ContextKeyChannelId ContextKey = "channel_id"
|
||||||
|
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||||
|
ContextKeyParamOverride ContextKey = "param_override"
|
||||||
|
|
||||||
|
/* user related keys */
|
||||||
|
ContextKeyUserId ContextKey = "id"
|
||||||
|
ContextKeyUserSetting ContextKey = "user_setting"
|
||||||
|
ContextKeyUserQuota ContextKey = "user_quota"
|
||||||
|
ContextKeyUserStatus ContextKey = "user_status"
|
||||||
|
ContextKeyUserEmail ContextKey = "user_email"
|
||||||
|
ContextKeyUserGroup ContextKey = "user_group"
|
||||||
|
ContextKeyUsingGroup ContextKey = "group"
|
||||||
|
ContextKeyUserName ContextKey = "username"
|
||||||
)
|
)
|
||||||
|
|||||||
11
constant/endpoint_type.go
Normal file
11
constant/endpoint_type.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
type EndpointType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
EndpointTypeOpenAI EndpointType = "openai"
|
||||||
|
EndpointTypeOpenAIResponse EndpointType = "openai-response"
|
||||||
|
EndpointTypeAnthropic EndpointType = "anthropic"
|
||||||
|
EndpointTypeGemini EndpointType = "gemini"
|
||||||
|
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||||
|
)
|
||||||
@@ -1,9 +1,5 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
import (
|
|
||||||
"one-api/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
var StreamingTimeout int
|
var StreamingTimeout int
|
||||||
var DifyDebug bool
|
var DifyDebug bool
|
||||||
var MaxFileDownloadMB int
|
var MaxFileDownloadMB int
|
||||||
@@ -17,39 +13,3 @@ var NotifyLimitCount int
|
|||||||
var NotificationLimitDurationMinute int
|
var NotificationLimitDurationMinute int
|
||||||
var GenerateDefaultToken bool
|
var GenerateDefaultToken bool
|
||||||
var ErrorLogEnabled bool
|
var ErrorLogEnabled bool
|
||||||
|
|
||||||
//var GeminiModelMap = map[string]string{
|
|
||||||
// "gemini-1.0-pro": "v1",
|
|
||||||
//}
|
|
||||||
|
|
||||||
func InitEnv() {
|
|
||||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
|
||||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
|
||||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
|
||||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
|
||||||
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
|
||||||
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
|
||||||
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
|
||||||
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
|
||||||
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
|
||||||
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
|
||||||
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
|
||||||
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
|
||||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
|
||||||
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
|
||||||
// 是否启用错误日志
|
|
||||||
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
|
||||||
|
|
||||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
|
||||||
//if modelVersionMapStr == "" {
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
|
||||||
// parts := strings.Split(pair, ":")
|
|
||||||
// if len(parts) == 2 {
|
|
||||||
// GeminiModelMap[parts[0]] = parts[1]
|
|
||||||
// } else {
|
|
||||||
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
@@ -341,34 +342,34 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||||
if channel.GetBaseURL() == "" {
|
if channel.GetBaseURL() == "" {
|
||||||
channel.BaseURL = &baseURL
|
channel.BaseURL = &baseURL
|
||||||
}
|
}
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeOpenAI:
|
case constant.ChannelTypeOpenAI:
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
case common.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
case common.ChannelTypeCustom:
|
case constant.ChannelTypeCustom:
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
//case common.ChannelTypeOpenAISB:
|
//case common.ChannelTypeOpenAISB:
|
||||||
// return updateChannelOpenAISBBalance(channel)
|
// return updateChannelOpenAISBBalance(channel)
|
||||||
case common.ChannelTypeAIProxy:
|
case constant.ChannelTypeAIProxy:
|
||||||
return updateChannelAIProxyBalance(channel)
|
return updateChannelAIProxyBalance(channel)
|
||||||
case common.ChannelTypeAPI2GPT:
|
case constant.ChannelTypeAPI2GPT:
|
||||||
return updateChannelAPI2GPTBalance(channel)
|
return updateChannelAPI2GPTBalance(channel)
|
||||||
case common.ChannelTypeAIGC2D:
|
case constant.ChannelTypeAIGC2D:
|
||||||
return updateChannelAIGC2DBalance(channel)
|
return updateChannelAIGC2DBalance(channel)
|
||||||
case common.ChannelTypeSiliconFlow:
|
case constant.ChannelTypeSiliconFlow:
|
||||||
return updateChannelSiliconFlowBalance(channel)
|
return updateChannelSiliconFlowBalance(channel)
|
||||||
case common.ChannelTypeDeepSeek:
|
case constant.ChannelTypeDeepSeek:
|
||||||
return updateChannelDeepSeekBalance(channel)
|
return updateChannelDeepSeekBalance(channel)
|
||||||
case common.ChannelTypeOpenRouter:
|
case constant.ChannelTypeOpenRouter:
|
||||||
return updateChannelOpenRouterBalance(channel)
|
return updateChannelOpenRouterBalance(channel)
|
||||||
case common.ChannelTypeMoonshot:
|
case constant.ChannelTypeMoonshot:
|
||||||
return updateChannelMoonshotBalance(channel)
|
return updateChannelMoonshotBalance(channel)
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -31,19 +31,19 @@ import (
|
|||||||
|
|
||||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
if channel.Type == common.ChannelTypeMidjourney {
|
if channel.Type == constant.ChannelTypeMidjourney {
|
||||||
return errors.New("midjourney channel test is not supported"), nil
|
return errors.New("midjourney channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeMidjourneyPlus {
|
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||||
return errors.New("midjourney plus channel test is not supported!!!"), nil
|
return errors.New("midjourney plus channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeSunoAPI {
|
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||||
return errors.New("suno channel test is not supported"), nil
|
return errors.New("suno channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeKling {
|
if channel.Type == constant.ChannelTypeKling {
|
||||||
return errors.New("kling channel test is not supported"), nil
|
return errors.New("kling channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeJimeng {
|
if channel.Type == constant.ChannelTypeJimeng {
|
||||||
return errors.New("jimeng channel test is not supported"), nil
|
return errors.New("jimeng channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -56,7 +56,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||||
strings.Contains(testModel, "embed") ||
|
strings.Contains(testModel, "embed") ||
|
||||||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
|
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||||
requestPath = "/v1/embeddings" // 修改请求路径
|
requestPath = "/v1/embeddings" // 修改请求路径
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
}
|
}
|
||||||
testModel = info.UpstreamModelName
|
testModel = info.UpstreamModelName
|
||||||
|
|
||||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -125,7 +126,7 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
order = "id desc"
|
order = "id desc"
|
||||||
}
|
}
|
||||||
|
|
||||||
err := baseQuery.Order(order).Limit(pageSize).Offset((p-1)*pageSize).Omit("key").Find(&channelData).Error
|
err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -181,15 +182,15 @@ func FetchUpstreamModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeGemini:
|
case constant.ChannelTypeGemini:
|
||||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||||
case common.ChannelTypeAli:
|
case constant.ChannelTypeAli:
|
||||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||||
}
|
}
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
@@ -213,7 +214,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
|||||||
var ids []string
|
var ids []string
|
||||||
for _, model := range result.Data {
|
for _, model := range result.Data {
|
||||||
id := model.ID
|
id := model.ID
|
||||||
if channel.Type == common.ChannelTypeGemini {
|
if channel.Type == constant.ChannelTypeGemini {
|
||||||
id = strings.TrimPrefix(id, "models/")
|
id = strings.TrimPrefix(id, "models/")
|
||||||
}
|
}
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
@@ -388,7 +389,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
channel.CreatedTime = common.GetTimestamp()
|
channel.CreatedTime = common.GetTimestamp()
|
||||||
keys := strings.Split(channel.Key, "\n")
|
keys := strings.Split(channel.Key, "\n")
|
||||||
if channel.Type == common.ChannelTypeVertexAi {
|
if channel.Type == constant.ChannelTypeVertexAi {
|
||||||
if channel.Other == "" {
|
if channel.Other == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -613,7 +614,7 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeVertexAi {
|
if channel.Type == constant.ChannelTypeVertexAi {
|
||||||
if channel.Other == "" {
|
if channel.Other == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -668,7 +669,7 @@ func FetchModels(c *gin.Context) {
|
|||||||
|
|
||||||
baseURL := req.BaseURL
|
baseURL := req.BaseURL
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = common.ChannelBaseURLs[req.Type]
|
baseURL = constant.ChannelBaseURLs[req.Type]
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -14,10 +15,7 @@ import (
|
|||||||
"one-api/relay/channel/minimax"
|
"one-api/relay/channel/minimax"
|
||||||
"one-api/relay/channel/moonshot"
|
"one-api/relay/channel/moonshot"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@@ -26,30 +24,10 @@ var openAIModels []dto.OpenAIModels
|
|||||||
var openAIModelsMap map[string]dto.OpenAIModels
|
var openAIModelsMap map[string]dto.OpenAIModels
|
||||||
var channelId2Models map[int][]string
|
var channelId2Models map[int][]string
|
||||||
|
|
||||||
func getPermission() []dto.OpenAIModelPermission {
|
|
||||||
var permission []dto.OpenAIModelPermission
|
|
||||||
permission = append(permission, dto.OpenAIModelPermission{
|
|
||||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
|
||||||
Object: "model_permission",
|
|
||||||
Created: 1626777600,
|
|
||||||
AllowCreateEngine: true,
|
|
||||||
AllowSampling: true,
|
|
||||||
AllowLogprobs: true,
|
|
||||||
AllowSearchIndices: false,
|
|
||||||
AllowView: true,
|
|
||||||
AllowFineTuning: false,
|
|
||||||
Organization: "*",
|
|
||||||
Group: nil,
|
|
||||||
IsBlocking: false,
|
|
||||||
})
|
|
||||||
return permission
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
permission := getPermission()
|
for i := 0; i < constant.APITypeDummy; i++ {
|
||||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
if i == constant.APITypeAIProxyLibrary {
|
||||||
if i == relayconstant.APITypeAIProxyLibrary {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
adaptor := relay.GetAdaptor(i)
|
adaptor := relay.GetAdaptor(i)
|
||||||
@@ -57,69 +35,51 @@ func init() {
|
|||||||
modelNames := adaptor.GetModelList()
|
modelNames := adaptor.GetModelList()
|
||||||
for _, modelName := range modelNames {
|
for _, modelName := range modelNames {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: channelName,
|
OwnedBy: channelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, modelName := range ai360.ModelList {
|
for _, modelName := range ai360.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: ai360.ChannelName,
|
OwnedBy: ai360.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range moonshot.ModelList {
|
for _, modelName := range moonshot.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: moonshot.ChannelName,
|
OwnedBy: moonshot.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range lingyiwanwu.ModelList {
|
for _, modelName := range lingyiwanwu.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: lingyiwanwu.ChannelName,
|
OwnedBy: lingyiwanwu.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range minimax.ModelList {
|
for _, modelName := range minimax.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: minimax.ChannelName,
|
OwnedBy: minimax.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: "midjourney",
|
OwnedBy: "midjourney",
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
||||||
@@ -127,9 +87,9 @@ func init() {
|
|||||||
openAIModelsMap[aiModel.Id] = aiModel
|
openAIModelsMap[aiModel.Id] = aiModel
|
||||||
}
|
}
|
||||||
channelId2Models = make(map[int][]string)
|
channelId2Models = make(map[int][]string)
|
||||||
for i := 1; i <= common.ChannelTypeDummy; i++ {
|
for i := 1; i <= constant.ChannelTypeDummy; i++ {
|
||||||
apiType, success := relayconstant.ChannelType2APIType(i)
|
apiType, success := common.ChannelType2APIType(i)
|
||||||
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
|
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
meta := &relaycommon.RelayInfo{ChannelType: i}
|
||||||
@@ -144,11 +104,10 @@ func init() {
|
|||||||
|
|
||||||
func ListModels(c *gin.Context) {
|
func ListModels(c *gin.Context) {
|
||||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||||
permission := getPermission()
|
|
||||||
|
|
||||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
if modelLimitEnable {
|
if modelLimitEnable {
|
||||||
s, ok := c.Get("token_model_limit")
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||||
var tokenModelLimit map[string]bool
|
var tokenModelLimit map[string]bool
|
||||||
if ok {
|
if ok {
|
||||||
tokenModelLimit = s.(map[string]bool)
|
tokenModelLimit = s.(map[string]bool)
|
||||||
@@ -156,17 +115,16 @@ func ListModels(c *gin.Context) {
|
|||||||
tokenModelLimit = map[string]bool{}
|
tokenModelLimit = map[string]bool{}
|
||||||
}
|
}
|
||||||
for allowModel, _ := range tokenModelLimit {
|
for allowModel, _ := range tokenModelLimit {
|
||||||
if _, ok := openAIModelsMap[allowModel]; ok {
|
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
|
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||||
|
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||||
} else {
|
} else {
|
||||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||||
Id: allowModel,
|
Id: allowModel,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: "custom",
|
OwnedBy: "custom",
|
||||||
Permission: permission,
|
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||||
Root: allowModel,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -181,14 +139,14 @@ func ListModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
group := userGroup
|
group := userGroup
|
||||||
tokenGroup := c.GetString("token_group")
|
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||||
if tokenGroup != "" {
|
if tokenGroup != "" {
|
||||||
group = tokenGroup
|
group = tokenGroup
|
||||||
}
|
}
|
||||||
var models []string
|
var models []string
|
||||||
if tokenGroup == "auto" {
|
if tokenGroup == "auto" {
|
||||||
for _, autoGroup := range setting.AutoGroups {
|
for _, autoGroup := range setting.AutoGroups {
|
||||||
groupModels := model.GetGroupModels(autoGroup)
|
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||||
for _, g := range groupModels {
|
for _, g := range groupModels {
|
||||||
if !common.StringsContains(models, g) {
|
if !common.StringsContains(models, g) {
|
||||||
models = append(models, g)
|
models = append(models, g)
|
||||||
@@ -196,20 +154,19 @@ func ListModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
models = model.GetGroupModels(group)
|
models = model.GetGroupEnabledModels(group)
|
||||||
}
|
}
|
||||||
for _, s := range models {
|
for _, modelName := range models {
|
||||||
if _, ok := openAIModelsMap[s]; ok {
|
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||||
|
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||||
} else {
|
} else {
|
||||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||||
Id: s,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: "custom",
|
OwnedBy: "custom",
|
||||||
Permission: permission,
|
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||||
Root: s,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ func Playground(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
|
|
||||||
// Write user context to ensure acceptUnsetRatio is available
|
// Write user context to ensure acceptUnsetRatio is available
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
constant2 "one-api/constant"
|
constant2 "one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/constant"
|
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -69,7 +69,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
@@ -132,7 +132,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||||
@@ -295,7 +295,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
|||||||
}
|
}
|
||||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||||
channelType := c.GetInt("channel_type")
|
channelType := c.GetInt("channel_type")
|
||||||
if channelType == common.ChannelTypeAnthropic {
|
if channelType == constant.ChannelTypeAnthropic {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
|||||||
}
|
}
|
||||||
|
|
||||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -487,7 +487,7 @@ func GetUserModels(c *gin.Context) {
|
|||||||
groups := setting.GetUserUsableGroups(user.Group)
|
groups := setting.GetUserUsableGroups(user.Group)
|
||||||
var models []string
|
var models []string
|
||||||
for group := range groups {
|
for group := range groups {
|
||||||
for _, g := range model.GetGroupModels(group) {
|
for _, g := range model.GetGroupEnabledModels(group) {
|
||||||
if !common.StringsContains(models, g) {
|
if !common.StringsContains(models, g) {
|
||||||
models = append(models, g)
|
models = append(models, g)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +1,11 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type OpenAIModelPermission struct {
|
import "one-api/constant"
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int `json:"created"`
|
|
||||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
|
||||||
AllowSampling bool `json:"allow_sampling"`
|
|
||||||
AllowLogprobs bool `json:"allow_logprobs"`
|
|
||||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
|
||||||
AllowView bool `json:"allow_view"`
|
|
||||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
|
||||||
Organization string `json:"organization"`
|
|
||||||
Group *string `json:"group"`
|
|
||||||
IsBlocking bool `json:"is_blocking"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIModels struct {
|
type OpenAIModels struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created int `json:"created"`
|
Created int `json:"created"`
|
||||||
OwnedBy string `json:"owned_by"`
|
OwnedBy string `json:"owned_by"`
|
||||||
Permission []OpenAIModelPermission `json:"permission"`
|
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||||
Root string `json:"root"`
|
|
||||||
Parent *string `json:"parent"`
|
|
||||||
}
|
}
|
||||||
|
|||||||
9
main.go
9
main.go
@@ -169,10 +169,8 @@ func InitResources() error {
|
|||||||
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 加载旧的(common)环境变量
|
// 加载环境变量
|
||||||
common.InitCommonEnv()
|
common.InitEnv()
|
||||||
// 加载constants的环境变量
|
|
||||||
constant.InitEnv()
|
|
||||||
|
|
||||||
// Initialize model settings
|
// Initialize model settings
|
||||||
ratio_setting.InitRatioSettings()
|
ratio_setting.InitRatioSettings()
|
||||||
@@ -193,6 +191,9 @@ func InitResources() error {
|
|||||||
// Initialize options, should after model.InitDB()
|
// Initialize options, should after model.InitDB()
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
|
|
||||||
|
// 初始化模型
|
||||||
|
model.GetPricing()
|
||||||
|
|
||||||
// Initialize SQL Database
|
// Initialize SQL Database
|
||||||
err = model.InitLogDB()
|
err = model.InitLogDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ type ModelRequest struct {
|
|||||||
|
|
||||||
func Distribute() func(c *gin.Context) {
|
func Distribute() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
allowIpsMap := c.GetStringMap("allow_ips")
|
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
||||||
if len(allowIpsMap) != 0 {
|
if len(allowIpsMap) != 0 {
|
||||||
clientIp := c.ClientIP()
|
clientIp := c.ClientIP()
|
||||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||||
@@ -34,14 +34,14 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
channelId, ok := c.Get("specific_channel_id")
|
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userGroup := c.GetString(constant.ContextKeyUserGroup)
|
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||||
tokenGroup := c.GetString("token_group")
|
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||||
if tokenGroup != "" {
|
if tokenGroup != "" {
|
||||||
// check common.UserUsableGroups[userGroup]
|
// check common.UserUsableGroups[userGroup]
|
||||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||||
@@ -57,7 +57,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
userGroup = tokenGroup
|
userGroup = tokenGroup
|
||||||
}
|
}
|
||||||
c.Set(constant.ContextKeyUsingGroup, userGroup)
|
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -76,9 +76,9 @@ func Distribute() func(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
// Select a channel for the user
|
// Select a channel for the user
|
||||||
// check token model mapping
|
// check token model mapping
|
||||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
if modelLimitEnable {
|
if modelLimitEnable {
|
||||||
s, ok := c.Get("token_model_limit")
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||||
var tokenModelLimit map[string]bool
|
var tokenModelLimit map[string]bool
|
||||||
if ok {
|
if ok {
|
||||||
tokenModelLimit = s.(map[string]bool)
|
tokenModelLimit = s.(map[string]bool)
|
||||||
@@ -121,7 +121,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
@@ -261,21 +261,21 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
// TODO: api_version统一
|
// TODO: api_version统一
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeVertexAi:
|
case constant.ChannelTypeVertexAi:
|
||||||
c.Set("region", channel.Other)
|
c.Set("region", channel.Other)
|
||||||
case common.ChannelTypeXunfei:
|
case constant.ChannelTypeXunfei:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeGemini:
|
case constant.ChannelTypeGemini:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeAli:
|
case constant.ChannelTypeAli:
|
||||||
c.Set("plugin", channel.Other)
|
c.Set("plugin", channel.Other)
|
||||||
case common.ChannelCloudflare:
|
case constant.ChannelCloudflare:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeMokaAI:
|
case constant.ChannelTypeMokaAI:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeCoze:
|
case constant.ChannelTypeCoze:
|
||||||
c.Set("bot_id", channel.Other)
|
c.Set("bot_id", channel.Other)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
|
|||||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||||
|
|
||||||
// 获取分组
|
// 获取分组
|
||||||
group := c.GetString("token_group")
|
group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||||
if group == "" {
|
if group == "" {
|
||||||
group = c.GetString(constant.ContextKeyUserGroup)
|
group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
//获取分组的限流配置
|
//获取分组的限流配置
|
||||||
|
|||||||
@@ -21,7 +21,22 @@ type Ability struct {
|
|||||||
Tag *string `json:"tag" gorm:"index"`
|
Tag *string `json:"tag" gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGroupModels(group string) []string {
|
type AbilityWithChannel struct {
|
||||||
|
Ability
|
||||||
|
ChannelType int `json:"channel_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
|
||||||
|
var abilities []AbilityWithChannel
|
||||||
|
err := DB.Table("abilities").
|
||||||
|
Select("abilities.*, channels.type as channel_type").
|
||||||
|
Joins("left join channels on abilities.channel_id = channels.id").
|
||||||
|
Where("abilities.enabled = ?", true).
|
||||||
|
Scan(&abilities).Error
|
||||||
|
return abilities, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetGroupEnabledModels(group string) []string {
|
||||||
var models []string
|
var models []string
|
||||||
// Find distinct models
|
// Find distinct models
|
||||||
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||||
@@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) {
|
|||||||
var priorities []int
|
var priorities []int
|
||||||
err := DB.Model(&Ability{}).
|
err := DB.Model(&Ability{}).
|
||||||
Select("DISTINCT(priority)").
|
Select("DISTINCT(priority)").
|
||||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
|
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
|
||||||
Order("priority DESC"). // 按优先级降序排序
|
Order("priority DESC"). // 按优先级降序排序
|
||||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||||
|
|
||||||
@@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
|
||||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
|
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
|
||||||
if retry != 0 {
|
if retry != 0 {
|
||||||
priority, err := getPriority(group, model, retry)
|
priority, err := getPriority(group, model, retry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||||
} else {
|
} else {
|
||||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
|
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
110
model/pricing.go
110
model/pricing.go
@@ -1,20 +1,24 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
"one-api/types"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Pricing struct {
|
type Pricing struct {
|
||||||
ModelName string `json:"model_name"`
|
ModelName string `json:"model_name"`
|
||||||
QuotaType int `json:"quota_type"`
|
QuotaType int `json:"quota_type"`
|
||||||
ModelRatio float64 `json:"model_ratio"`
|
ModelRatio float64 `json:"model_ratio"`
|
||||||
ModelPrice float64 `json:"model_price"`
|
ModelPrice float64 `json:"model_price"`
|
||||||
OwnerBy string `json:"owner_by"`
|
OwnerBy string `json:"owner_by"`
|
||||||
CompletionRatio float64 `json:"completion_ratio"`
|
CompletionRatio float64 `json:"completion_ratio"`
|
||||||
EnableGroup []string `json:"enable_groups,omitempty"`
|
EnableGroup []string `json:"enable_groups"`
|
||||||
|
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -23,47 +27,89 @@ var (
|
|||||||
updatePricingLock sync.Mutex
|
updatePricingLock sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetPricing() []Pricing {
|
var (
|
||||||
updatePricingLock.Lock()
|
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||||
defer updatePricingLock.Unlock()
|
modelSupportEndpointsLock = sync.RWMutex{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetPricing() []Pricing {
|
||||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||||
updatePricing()
|
updatePricingLock.Lock()
|
||||||
|
defer updatePricingLock.Unlock()
|
||||||
|
// Double check after acquiring the lock
|
||||||
|
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||||
|
modelSupportEndpointsLock.Lock()
|
||||||
|
defer modelSupportEndpointsLock.Unlock()
|
||||||
|
updatePricing()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
//if group != "" {
|
|
||||||
// userPricingMap := make([]Pricing, 0)
|
|
||||||
// models := GetGroupModels(group)
|
|
||||||
// for _, pricing := range pricingMap {
|
|
||||||
// if !common.StringsContains(models, pricing.ModelName) {
|
|
||||||
// pricing.Available = false
|
|
||||||
// }
|
|
||||||
// userPricingMap = append(userPricingMap, pricing)
|
|
||||||
// }
|
|
||||||
// return userPricingMap
|
|
||||||
//}
|
|
||||||
return pricingMap
|
return pricingMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||||||
|
if model == "" {
|
||||||
|
return make([]constant.EndpointType, 0)
|
||||||
|
}
|
||||||
|
modelSupportEndpointsLock.RLock()
|
||||||
|
defer modelSupportEndpointsLock.RUnlock()
|
||||||
|
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
|
||||||
|
return endpoints
|
||||||
|
}
|
||||||
|
return make([]constant.EndpointType, 0)
|
||||||
|
}
|
||||||
|
|
||||||
func updatePricing() {
|
func updatePricing() {
|
||||||
//modelRatios := common.GetModelRatios()
|
//modelRatios := common.GetModelRatios()
|
||||||
enableAbilities := GetAllEnableAbilities()
|
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||||||
modelGroupsMap := make(map[string][]string)
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelGroupsMap := make(map[string]*types.Set[string])
|
||||||
|
|
||||||
for _, ability := range enableAbilities {
|
for _, ability := range enableAbilities {
|
||||||
groups := modelGroupsMap[ability.Model]
|
groups, ok := modelGroupsMap[ability.Model]
|
||||||
if groups == nil {
|
if !ok {
|
||||||
groups = make([]string, 0)
|
groups = types.NewSet[string]()
|
||||||
|
modelGroupsMap[ability.Model] = groups
|
||||||
}
|
}
|
||||||
if !common.StringsContains(groups, ability.Group) {
|
groups.Add(ability.Group)
|
||||||
groups = append(groups, ability.Group)
|
}
|
||||||
|
|
||||||
|
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||||||
|
modelSupportEndpointsStr := make(map[string][]string)
|
||||||
|
|
||||||
|
for _, ability := range enableAbilities {
|
||||||
|
endpoints, ok := modelSupportEndpointsStr[ability.Model]
|
||||||
|
if !ok {
|
||||||
|
endpoints = make([]string, 0)
|
||||||
|
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||||
}
|
}
|
||||||
modelGroupsMap[ability.Model] = groups
|
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||||||
|
for _, channelType := range channelTypes {
|
||||||
|
if !common.StringsContains(endpoints, string(channelType)) {
|
||||||
|
endpoints = append(endpoints, string(channelType))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||||
|
for model, endpoints := range modelSupportEndpointsStr {
|
||||||
|
supportedEndpoints := make([]constant.EndpointType, 0)
|
||||||
|
for _, endpointStr := range endpoints {
|
||||||
|
endpointType := constant.EndpointType(endpointStr)
|
||||||
|
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||||||
|
}
|
||||||
|
modelSupportEndpointTypes[model] = supportedEndpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
pricingMap = make([]Pricing, 0)
|
pricingMap = make([]Pricing, 0)
|
||||||
for model, groups := range modelGroupsMap {
|
for model, groups := range modelGroupsMap {
|
||||||
pricing := Pricing{
|
pricing := Pricing{
|
||||||
ModelName: model,
|
ModelName: model,
|
||||||
EnableGroup: groups,
|
EnableGroup: groups.Items(),
|
||||||
|
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||||||
}
|
}
|
||||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||||
if findPrice {
|
if findPrice {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
func cacheSetToken(token Token) error {
|
func cacheSetToken(token Token) error {
|
||||||
key := common.GenerateHMAC(token.Key)
|
key := common.GenerateHMAC(token.Key)
|
||||||
token.Clean()
|
token.Clean()
|
||||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
|
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,12 +24,12 @@ type UserBase struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (user *UserBase) WriteContext(c *gin.Context) {
|
func (user *UserBase) WriteContext(c *gin.Context) {
|
||||||
c.Set(constant.ContextKeyUserGroup, user.Group)
|
common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
|
||||||
c.Set(constant.ContextKeyUserQuota, user.Quota)
|
common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
|
||||||
c.Set(constant.ContextKeyUserStatus, user.Status)
|
common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
|
||||||
c.Set(constant.ContextKeyUserEmail, user.Email)
|
common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
|
||||||
c.Set("username", user.Username)
|
common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
|
||||||
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
|
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *UserBase) GetSetting() map[string]interface{} {
|
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||||
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
|
|||||||
return common.RedisHSetObj(
|
return common.RedisHSetObj(
|
||||||
getUserCacheKey(user.Id),
|
getUserCacheKey(user.Id),
|
||||||
user.ToBaseUser(),
|
user.ToBaseUser(),
|
||||||
time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
|
time.Duration(common.RedisKeyCacheSeconds())*time.Second,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,7 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"one-api/common"
|
"one-api/constant"
|
||||||
constant2 "one-api/constant"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/ai360"
|
"one-api/relay/channel/ai360"
|
||||||
@@ -21,7 +20,7 @@ import (
|
|||||||
"one-api/relay/channel/xinference"
|
"one-api/relay/channel/xinference"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/common_handler"
|
"one-api/relay/common_handler"
|
||||||
"one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -54,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
a.ChannelType = info.ChannelType
|
a.ChannelType = info.ChannelType
|
||||||
|
|
||||||
// initialize ThinkingContentInfo when thinking_to_content is enabled
|
// initialize ThinkingContentInfo when thinking_to_content is enabled
|
||||||
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
||||||
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
|
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
|
||||||
IsFirstThinkingContent: true,
|
IsFirstThinkingContent: true,
|
||||||
SendLastThinkingContent: false,
|
SendLastThinkingContent: false,
|
||||||
@@ -67,7 +66,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||||
}
|
}
|
||||||
if info.RelayMode == constant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||||
baseUrl = "wss://" + baseUrl
|
baseUrl = "wss://" + baseUrl
|
||||||
@@ -79,10 +78,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch info.ChannelType {
|
switch info.ChannelType {
|
||||||
case common.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
apiVersion := info.ApiVersion
|
apiVersion := info.ApiVersion
|
||||||
if apiVersion == "" {
|
if apiVersion == "" {
|
||||||
apiVersion = constant2.AzureDefaultAPIVersion
|
apiVersion = constant.AzureDefaultAPIVersion
|
||||||
}
|
}
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||||||
@@ -90,25 +89,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||||
|
|
||||||
// 特殊处理 responses API
|
// 特殊处理 responses API
|
||||||
if info.RelayMode == constant.RelayModeResponses {
|
if info.RelayMode == relayconstant.RelayModeResponses {
|
||||||
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
|
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
|
||||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
model_ := info.UpstreamModelName
|
model_ := info.UpstreamModelName
|
||||||
// 2025年5月10日后创建的渠道不移除.
|
// 2025年5月10日后创建的渠道不移除.
|
||||||
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
|
if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
|
||||||
model_ = strings.Replace(model_, ".", "", -1)
|
model_ = strings.Replace(model_, ".", "", -1)
|
||||||
}
|
}
|
||||||
// https://github.com/songquanpeng/one-api/issues/67
|
// https://github.com/songquanpeng/one-api/issues/67
|
||||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||||
if info.RelayMode == constant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
||||||
}
|
}
|
||||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||||
case common.ChannelTypeMiniMax:
|
case constant.ChannelTypeMiniMax:
|
||||||
return minimax.GetRequestURL(info)
|
return minimax.GetRequestURL(info)
|
||||||
case common.ChannelTypeCustom:
|
case constant.ChannelTypeCustom:
|
||||||
url := info.BaseUrl
|
url := info.BaseUrl
|
||||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||||
return url, nil
|
return url, nil
|
||||||
@@ -119,14 +118,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, header)
|
channel.SetupApiRequestHeader(info, c, header)
|
||||||
if info.ChannelType == common.ChannelTypeAzure {
|
if info.ChannelType == constant.ChannelTypeAzure {
|
||||||
header.Set("api-key", info.ApiKey)
|
header.Set("api-key", info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
|
||||||
header.Set("OpenAI-Organization", info.Organization)
|
header.Set("OpenAI-Organization", info.Organization)
|
||||||
}
|
}
|
||||||
if info.RelayMode == constant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||||
if swp != "" {
|
if swp != "" {
|
||||||
items := []string{
|
items := []string{
|
||||||
@@ -145,7 +144,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
|||||||
} else {
|
} else {
|
||||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||||
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
||||||
header.Set("X-Title", "New API")
|
header.Set("X-Title", "New API")
|
||||||
}
|
}
|
||||||
@@ -156,10 +155,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
|
if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
|
||||||
request.StreamOptions = nil
|
request.StreamOptions = nil
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||||
if len(request.Usage) == 0 {
|
if len(request.Usage) == 0 {
|
||||||
request.Usage = json.RawMessage(`{"include":true}`)
|
request.Usage = json.RawMessage(`{"include":true}`)
|
||||||
}
|
}
|
||||||
@@ -205,7 +204,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
a.ResponseFormat = request.ResponseFormat
|
a.ResponseFormat = request.ResponseFormat
|
||||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
if info.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||||
jsonData, err := json.Marshal(request)
|
jsonData, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error marshalling object: %w", err)
|
return nil, fmt.Errorf("error marshalling object: %w", err)
|
||||||
@@ -254,7 +253,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesEdits:
|
||||||
|
|
||||||
var requestBody bytes.Buffer
|
var requestBody bytes.Buffer
|
||||||
writer := multipart.NewWriter(&requestBody)
|
writer := multipart.NewWriter(&requestBody)
|
||||||
@@ -411,11 +410,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
if info.RelayMode == constant.RelayModeAudioTranscription ||
|
if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
|
||||||
info.RelayMode == constant.RelayModeAudioTranslation ||
|
info.RelayMode == relayconstant.RelayModeAudioTranslation ||
|
||||||
info.RelayMode == constant.RelayModeImagesEdits {
|
info.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||||
return channel.DoFormRequest(a, c, info, requestBody)
|
return channel.DoFormRequest(a, c, info, requestBody)
|
||||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
} else if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
return channel.DoWssRequest(a, c, info, requestBody)
|
return channel.DoWssRequest(a, c, info, requestBody)
|
||||||
} else {
|
} else {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
@@ -424,19 +423,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeRealtime:
|
case relayconstant.RelayModeRealtime:
|
||||||
err, usage = OpenaiRealtimeHandler(c, info)
|
err, usage = OpenaiRealtimeHandler(c, info)
|
||||||
case constant.RelayModeAudioSpeech:
|
case relayconstant.RelayModeAudioSpeech:
|
||||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||||
case constant.RelayModeAudioTranslation:
|
case relayconstant.RelayModeAudioTranslation:
|
||||||
fallthrough
|
fallthrough
|
||||||
case constant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||||
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
||||||
case constant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||||
case constant.RelayModeResponses:
|
case relayconstant.RelayModeResponses:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = OaiResponsesStreamHandler(c, resp, info)
|
err, usage = OaiResponsesStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
@@ -454,17 +453,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
|
|
||||||
func (a *Adaptor) GetModelList() []string {
|
func (a *Adaptor) GetModelList() []string {
|
||||||
switch a.ChannelType {
|
switch a.ChannelType {
|
||||||
case common.ChannelType360:
|
case constant.ChannelType360:
|
||||||
return ai360.ModelList
|
return ai360.ModelList
|
||||||
case common.ChannelTypeMoonshot:
|
case constant.ChannelTypeMoonshot:
|
||||||
return moonshot.ModelList
|
return moonshot.ModelList
|
||||||
case common.ChannelTypeLingYiWanWu:
|
case constant.ChannelTypeLingYiWanWu:
|
||||||
return lingyiwanwu.ModelList
|
return lingyiwanwu.ModelList
|
||||||
case common.ChannelTypeMiniMax:
|
case constant.ChannelTypeMiniMax:
|
||||||
return minimax.ModelList
|
return minimax.ModelList
|
||||||
case common.ChannelTypeXinference:
|
case constant.ChannelTypeXinference:
|
||||||
return xinference.ModelList
|
return xinference.ModelList
|
||||||
case common.ChannelTypeOpenRouter:
|
case constant.ChannelTypeOpenRouter:
|
||||||
return openrouter.ModelList
|
return openrouter.ModelList
|
||||||
default:
|
default:
|
||||||
return ModelList
|
return ModelList
|
||||||
@@ -473,17 +472,17 @@ func (a *Adaptor) GetModelList() []string {
|
|||||||
|
|
||||||
func (a *Adaptor) GetChannelName() string {
|
func (a *Adaptor) GetChannelName() string {
|
||||||
switch a.ChannelType {
|
switch a.ChannelType {
|
||||||
case common.ChannelType360:
|
case constant.ChannelType360:
|
||||||
return ai360.ChannelName
|
return ai360.ChannelName
|
||||||
case common.ChannelTypeMoonshot:
|
case constant.ChannelTypeMoonshot:
|
||||||
return moonshot.ChannelName
|
return moonshot.ChannelName
|
||||||
case common.ChannelTypeLingYiWanWu:
|
case constant.ChannelTypeLingYiWanWu:
|
||||||
return lingyiwanwu.ChannelName
|
return lingyiwanwu.ChannelName
|
||||||
case common.ChannelTypeMiniMax:
|
case constant.ChannelTypeMiniMax:
|
||||||
return minimax.ChannelName
|
return minimax.ChannelName
|
||||||
case common.ChannelTypeXinference:
|
case constant.ChannelTypeXinference:
|
||||||
return xinference.ChannelName
|
return xinference.ChannelName
|
||||||
case common.ChannelTypeOpenRouter:
|
case constant.ChannelTypeOpenRouter:
|
||||||
return openrouter.ChannelName
|
return openrouter.ChannelName
|
||||||
default:
|
default:
|
||||||
return ChannelName
|
return ChannelName
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
} else {
|
} else {
|
||||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
if info.ChannelType == constant.ChannelTypeDeepSeek {
|
||||||
if usage.PromptCacheHitTokens != 0 {
|
if usage.PromptCacheHitTokens != 0 {
|
||||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -113,17 +113,17 @@ type RelayInfo struct {
|
|||||||
|
|
||||||
// 定义支持流式选项的通道类型
|
// 定义支持流式选项的通道类型
|
||||||
var streamSupportedChannels = map[int]bool{
|
var streamSupportedChannels = map[int]bool{
|
||||||
common.ChannelTypeOpenAI: true,
|
constant.ChannelTypeOpenAI: true,
|
||||||
common.ChannelTypeAnthropic: true,
|
constant.ChannelTypeAnthropic: true,
|
||||||
common.ChannelTypeAws: true,
|
constant.ChannelTypeAws: true,
|
||||||
common.ChannelTypeGemini: true,
|
constant.ChannelTypeGemini: true,
|
||||||
common.ChannelCloudflare: true,
|
constant.ChannelCloudflare: true,
|
||||||
common.ChannelTypeAzure: true,
|
constant.ChannelTypeAzure: true,
|
||||||
common.ChannelTypeVolcEngine: true,
|
constant.ChannelTypeVolcEngine: true,
|
||||||
common.ChannelTypeOllama: true,
|
constant.ChannelTypeOllama: true,
|
||||||
common.ChannelTypeXai: true,
|
constant.ChannelTypeXai: true,
|
||||||
common.ChannelTypeDeepSeek: true,
|
constant.ChannelTypeDeepSeek: true,
|
||||||
common.ChannelTypeBaiduV2: true,
|
constant.ChannelTypeBaiduV2: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||||
@@ -211,40 +211,40 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||||
channelType := c.GetInt("channel_type")
|
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||||
channelSetting := c.GetStringMap("channel_setting")
|
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
|
||||||
paramOverride := c.GetStringMap("param_override")
|
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
||||||
tokenKey := c.GetString("token_key")
|
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
||||||
userId := c.GetInt("id")
|
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
|
||||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
|
||||||
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
|
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
|
||||||
// firstResponseTime = time.Now() - 1 second
|
// firstResponseTime = time.Now() - 1 second
|
||||||
|
|
||||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
apiType, _ := common.ChannelType2APIType(channelType)
|
||||||
|
|
||||||
info := &RelayInfo{
|
info := &RelayInfo{
|
||||||
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
|
||||||
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
|
||||||
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||||
isFirstResponse: true,
|
isFirstResponse: true,
|
||||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||||
BaseUrl: c.GetString("base_url"),
|
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
|
||||||
RequestURLPath: c.Request.URL.String(),
|
RequestURLPath: c.Request.URL.String(),
|
||||||
ChannelType: channelType,
|
ChannelType: channelType,
|
||||||
ChannelId: channelId,
|
ChannelId: channelId,
|
||||||
TokenId: tokenId,
|
TokenId: tokenId,
|
||||||
TokenKey: tokenKey,
|
TokenKey: tokenKey,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
UsingGroup: c.GetString(constant.ContextKeyUsingGroup),
|
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
|
||||||
UserGroup: c.GetString(constant.ContextKeyUserGroup),
|
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
|
||||||
TokenUnlimited: tokenUnlimited,
|
TokenUnlimited: tokenUnlimited,
|
||||||
StartTime: startTime,
|
StartTime: startTime,
|
||||||
FirstResponseTime: startTime.Add(-time.Second),
|
FirstResponseTime: startTime.Add(-time.Second),
|
||||||
OriginModelName: c.GetString("original_model"),
|
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||||
UpstreamModelName: c.GetString("original_model"),
|
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||||
//RecodeModelName: c.GetString("original_model"),
|
//RecodeModelName: c.GetString("original_model"),
|
||||||
IsModelMapped: false,
|
IsModelMapped: false,
|
||||||
ApiType: apiType,
|
ApiType: apiType,
|
||||||
@@ -266,12 +266,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
info.RequestURLPath = "/v1" + info.RequestURLPath
|
info.RequestURLPath = "/v1" + info.RequestURLPath
|
||||||
}
|
}
|
||||||
if info.BaseUrl == "" {
|
if info.BaseUrl == "" {
|
||||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
info.BaseUrl = constant.ChannelBaseURLs[channelType]
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeAzure {
|
if info.ChannelType == constant.ChannelTypeAzure {
|
||||||
info.ApiVersion = GetAPIVersion(c)
|
info.ApiVersion = GetAPIVersion(c)
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeVertexAi {
|
if info.ChannelType == constant.ChannelTypeVertexAi {
|
||||||
info.ApiVersion = c.GetString("region")
|
info.ApiVersion = c.GetString("region")
|
||||||
}
|
}
|
||||||
if streamSupportedChannels[info.ChannelType] {
|
if streamSupportedChannels[info.ChannelType] {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
_ "image/gif"
|
_ "image/gif"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
_ "image/png"
|
_ "image/png"
|
||||||
"one-api/common"
|
"one-api/constant"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,9 +15,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
|
|||||||
|
|
||||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||||
switch channelType {
|
switch channelType {
|
||||||
case common.ChannelTypeOpenAI:
|
case constant.ChannelTypeOpenAI:
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||||||
case common.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel/xinference"
|
"one-api/relay/channel/xinference"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -21,7 +22,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
|||||||
println("reranker response body: ", string(responseBody))
|
println("reranker response body: ", string(responseBody))
|
||||||
}
|
}
|
||||||
var jinaResp dto.RerankResponse
|
var jinaResp dto.RerankResponse
|
||||||
if info.ChannelType == common.ChannelTypeXinference {
|
if info.ChannelType == constant.ChannelTypeXinference {
|
||||||
var xinRerankResponse xinference.XinRerankResponse
|
var xinRerankResponse xinference.XinRerankResponse
|
||||||
err = common.UnmarshalJson(responseBody, &xinRerankResponse)
|
err = common.UnmarshalJson(responseBody, &xinRerankResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,106 +0,0 @@
|
|||||||
package constant
|
|
||||||
|
|
||||||
import (
|
|
||||||
"one-api/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
APITypeOpenAI = iota
|
|
||||||
APITypeAnthropic
|
|
||||||
APITypePaLM
|
|
||||||
APITypeBaidu
|
|
||||||
APITypeZhipu
|
|
||||||
APITypeAli
|
|
||||||
APITypeXunfei
|
|
||||||
APITypeAIProxyLibrary
|
|
||||||
APITypeTencent
|
|
||||||
APITypeGemini
|
|
||||||
APITypeZhipuV4
|
|
||||||
APITypeOllama
|
|
||||||
APITypePerplexity
|
|
||||||
APITypeAws
|
|
||||||
APITypeCohere
|
|
||||||
APITypeDify
|
|
||||||
APITypeJina
|
|
||||||
APITypeCloudflare
|
|
||||||
APITypeSiliconFlow
|
|
||||||
APITypeVertexAi
|
|
||||||
APITypeMistral
|
|
||||||
APITypeDeepSeek
|
|
||||||
APITypeMokaAI
|
|
||||||
APITypeVolcEngine
|
|
||||||
APITypeBaiduV2
|
|
||||||
APITypeOpenRouter
|
|
||||||
APITypeXinference
|
|
||||||
APITypeXai
|
|
||||||
APITypeCoze
|
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
|
||||||
)
|
|
||||||
|
|
||||||
func ChannelType2APIType(channelType int) (int, bool) {
|
|
||||||
apiType := -1
|
|
||||||
switch channelType {
|
|
||||||
case common.ChannelTypeOpenAI:
|
|
||||||
apiType = APITypeOpenAI
|
|
||||||
case common.ChannelTypeAnthropic:
|
|
||||||
apiType = APITypeAnthropic
|
|
||||||
case common.ChannelTypeBaidu:
|
|
||||||
apiType = APITypeBaidu
|
|
||||||
case common.ChannelTypePaLM:
|
|
||||||
apiType = APITypePaLM
|
|
||||||
case common.ChannelTypeZhipu:
|
|
||||||
apiType = APITypeZhipu
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
apiType = APITypeAli
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
apiType = APITypeXunfei
|
|
||||||
case common.ChannelTypeAIProxyLibrary:
|
|
||||||
apiType = APITypeAIProxyLibrary
|
|
||||||
case common.ChannelTypeTencent:
|
|
||||||
apiType = APITypeTencent
|
|
||||||
case common.ChannelTypeGemini:
|
|
||||||
apiType = APITypeGemini
|
|
||||||
case common.ChannelTypeZhipu_v4:
|
|
||||||
apiType = APITypeZhipuV4
|
|
||||||
case common.ChannelTypeOllama:
|
|
||||||
apiType = APITypeOllama
|
|
||||||
case common.ChannelTypePerplexity:
|
|
||||||
apiType = APITypePerplexity
|
|
||||||
case common.ChannelTypeAws:
|
|
||||||
apiType = APITypeAws
|
|
||||||
case common.ChannelTypeCohere:
|
|
||||||
apiType = APITypeCohere
|
|
||||||
case common.ChannelTypeDify:
|
|
||||||
apiType = APITypeDify
|
|
||||||
case common.ChannelTypeJina:
|
|
||||||
apiType = APITypeJina
|
|
||||||
case common.ChannelCloudflare:
|
|
||||||
apiType = APITypeCloudflare
|
|
||||||
case common.ChannelTypeSiliconFlow:
|
|
||||||
apiType = APITypeSiliconFlow
|
|
||||||
case common.ChannelTypeVertexAi:
|
|
||||||
apiType = APITypeVertexAi
|
|
||||||
case common.ChannelTypeMistral:
|
|
||||||
apiType = APITypeMistral
|
|
||||||
case common.ChannelTypeDeepSeek:
|
|
||||||
apiType = APITypeDeepSeek
|
|
||||||
case common.ChannelTypeMokaAI:
|
|
||||||
apiType = APITypeMokaAI
|
|
||||||
case common.ChannelTypeVolcEngine:
|
|
||||||
apiType = APITypeVolcEngine
|
|
||||||
case common.ChannelTypeBaiduV2:
|
|
||||||
apiType = APITypeBaiduV2
|
|
||||||
case common.ChannelTypeOpenRouter:
|
|
||||||
apiType = APITypeOpenRouter
|
|
||||||
case common.ChannelTypeXinference:
|
|
||||||
apiType = APITypeXinference
|
|
||||||
case common.ChannelTypeXai:
|
|
||||||
apiType = APITypeXai
|
|
||||||
case common.ChannelTypeCoze:
|
|
||||||
apiType = APITypeCoze
|
|
||||||
}
|
|
||||||
if apiType == -1 {
|
|
||||||
return APITypeOpenAI, false
|
|
||||||
}
|
|
||||||
return apiType, true
|
|
||||||
}
|
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -17,8 +18,6 @@ import (
|
|||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"one-api/relay/constant"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package relay
|
package relay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"one-api/constant"
|
||||||
commonconstant "one-api/constant"
|
commonconstant "one-api/constant"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/ali"
|
"one-api/relay/channel/ali"
|
||||||
@@ -32,7 +33,6 @@ import (
|
|||||||
"one-api/relay/channel/xunfei"
|
"one-api/relay/channel/xunfei"
|
||||||
"one-api/relay/channel/zhipu"
|
"one-api/relay/channel/zhipu"
|
||||||
"one-api/relay/channel/zhipu_4v"
|
"one-api/relay/channel/zhipu_4v"
|
||||||
"one-api/relay/constant"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAdaptor(apiType int) channel.Adaptor {
|
func GetAdaptor(apiType int) channel.Adaptor {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/operation_setting"
|
||||||
@@ -48,7 +49,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
|
|||||||
}
|
}
|
||||||
if err.StatusCode == http.StatusForbidden {
|
if err.StatusCode == http.StatusForbidden {
|
||||||
switch channelType {
|
switch channelType {
|
||||||
case common.ChannelTypeGemini:
|
case constant.ChannelTypeGemini:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel/openrouter"
|
"one-api/relay/channel/openrouter"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -19,7 +20,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
|
|||||||
Stream: claudeRequest.Stream,
|
Stream: claudeRequest.Stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
|
isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
|
||||||
|
|
||||||
if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
|
if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
|
||||||
if isOpenRouter {
|
if isOpenRouter {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
constant2 "one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -232,7 +232,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|||||||
cacheCreationRatio := priceData.CacheCreationRatio
|
cacheCreationRatio := priceData.CacheCreationRatio
|
||||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||||
|
|
||||||
if relayInfo.ChannelType == common.ChannelTypeOpenRouter {
|
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
|
||||||
promptTokens -= cacheTokens
|
promptTokens -= cacheTokens
|
||||||
if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
|
if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
|
||||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
|
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
|
||||||
@@ -447,7 +447,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
userSetting := relayInfo.UserSetting
|
userSetting := relayInfo.UserSetting
|
||||||
threshold := common.QuotaRemindThreshold
|
threshold := common.QuotaRemindThreshold
|
||||||
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
|
if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok {
|
||||||
threshold = int(userCustomThreshold.(float64))
|
threshold = int(userCustomThreshold.(float64))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
|
|||||||
if !constant.GetMediaToken {
|
if !constant.GetMediaToken {
|
||||||
return 3 * baseTokens, nil
|
return 3 * baseTokens, nil
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
|
if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
|
||||||
return 3 * baseTokens, nil
|
return 3 * baseTokens, nil
|
||||||
}
|
}
|
||||||
var config image.Config
|
var config image.Config
|
||||||
|
|||||||
42
types/set.go
Normal file
42
types/set.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type Set[T comparable] struct {
|
||||||
|
items map[T]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSet 创建并返回一个新的 Set
|
||||||
|
func NewSet[T comparable]() *Set[T] {
|
||||||
|
return &Set[T]{
|
||||||
|
items: make(map[T]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Set[T]) Add(item T) {
|
||||||
|
s.items[item] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove 从 Set 中移除一个元素
|
||||||
|
func (s *Set[T]) Remove(item T) {
|
||||||
|
delete(s.items, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains 检查 Set 是否包含某个元素
|
||||||
|
func (s *Set[T]) Contains(item T) bool {
|
||||||
|
_, exists := s.items[item]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len 返回 Set 中元素的数量
|
||||||
|
func (s *Set[T]) Len() int {
|
||||||
|
return len(s.items)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Items 返回 Set 中所有元素组成的切片
|
||||||
|
// 注意:由于 map 的无序性,返回的切片元素顺序是随机的
|
||||||
|
func (s *Set[T]) Items() []T {
|
||||||
|
items := make([]T, 0, s.Len())
|
||||||
|
for item := range s.items {
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
return items
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user