feat: refactor environment variable initialization and introduce new constant types for API and context keys

This commit is contained in:
CaIon
2025-07-03 13:10:25 +08:00
parent 34aca14858
commit 7e298f8ad1
43 changed files with 749 additions and 576 deletions

View File

@@ -9,8 +9,7 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
"one-api/common"
constant2 "one-api/constant"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
@@ -21,7 +20,7 @@ import (
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/service"
"path/filepath"
"strings"
@@ -54,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
// initialize ThinkingContentInfo when thinking_to_content is enabled
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
@@ -67,7 +66,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == constant.RelayModeRealtime {
if info.RelayMode == relayconstant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
baseUrl = "wss://" + baseUrl
@@ -79,10 +78,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
switch info.ChannelType {
case common.ChannelTypeAzure:
case constant.ChannelTypeAzure:
apiVersion := info.ApiVersion
if apiVersion == "" {
apiVersion = constant2.AzureDefaultAPIVersion
apiVersion = constant.AzureDefaultAPIVersion
}
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(info.RequestURLPath, "?")[0]
@@ -90,25 +89,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
task := strings.TrimPrefix(requestURL, "/v1/")
// 特殊处理 responses API
if info.RelayMode == constant.RelayModeResponses {
if info.RelayMode == relayconstant.RelayModeResponses {
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
}
model_ := info.UpstreamModelName
// 2025年5月10日后创建的渠道不移除.
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
model_ = strings.Replace(model_, ".", "", -1)
}
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
if info.RelayMode == relayconstant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax:
case constant.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
case common.ChannelTypeCustom:
case constant.ChannelTypeCustom:
url := info.BaseUrl
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
return url, nil
@@ -119,14 +118,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, header)
if info.ChannelType == common.ChannelTypeAzure {
if info.ChannelType == constant.ChannelTypeAzure {
header.Set("api-key", info.ApiKey)
return nil
}
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
header.Set("OpenAI-Organization", info.Organization)
}
if info.RelayMode == constant.RelayModeRealtime {
if info.RelayMode == relayconstant.RelayModeRealtime {
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
if swp != "" {
items := []string{
@@ -145,7 +144,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
} else {
header.Set("Authorization", "Bearer "+info.ApiKey)
}
if info.ChannelType == common.ChannelTypeOpenRouter {
if info.ChannelType == constant.ChannelTypeOpenRouter {
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
header.Set("X-Title", "New API")
}
@@ -156,10 +155,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
request.StreamOptions = nil
}
if info.ChannelType == common.ChannelTypeOpenRouter {
if info.ChannelType == constant.ChannelTypeOpenRouter {
if len(request.Usage) == 0 {
request.Usage = json.RawMessage(`{"include":true}`)
}
@@ -205,7 +204,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
a.ResponseFormat = request.ResponseFormat
if info.RelayMode == constant.RelayModeAudioSpeech {
if info.RelayMode == relayconstant.RelayModeAudioSpeech {
jsonData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshalling object: %w", err)
@@ -254,7 +253,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
case constant.RelayModeImagesEdits:
case relayconstant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
@@ -411,11 +410,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if info.RelayMode == constant.RelayModeAudioTranscription ||
info.RelayMode == constant.RelayModeAudioTranslation ||
info.RelayMode == constant.RelayModeImagesEdits {
if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
info.RelayMode == relayconstant.RelayModeAudioTranslation ||
info.RelayMode == relayconstant.RelayModeImagesEdits {
return channel.DoFormRequest(a, c, info, requestBody)
} else if info.RelayMode == constant.RelayModeRealtime {
} else if info.RelayMode == relayconstant.RelayModeRealtime {
return channel.DoWssRequest(a, c, info, requestBody)
} else {
return channel.DoApiRequest(a, c, info, requestBody)
@@ -424,19 +423,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeRealtime:
case relayconstant.RelayModeRealtime:
err, usage = OpenaiRealtimeHandler(c, info)
case constant.RelayModeAudioSpeech:
case relayconstant.RelayModeAudioSpeech:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeAudioTranslation:
case relayconstant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
case relayconstant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
err, usage = OpenaiHandlerWithUsage(c, resp, info)
case constant.RelayModeRerank:
case relayconstant.RelayModeRerank:
err, usage = common_handler.RerankHandler(c, info, resp)
case constant.RelayModeResponses:
case relayconstant.RelayModeResponses:
if info.IsStream {
err, usage = OaiResponsesStreamHandler(c, resp, info)
} else {
@@ -454,17 +453,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
func (a *Adaptor) GetModelList() []string {
switch a.ChannelType {
case common.ChannelType360:
case constant.ChannelType360:
return ai360.ModelList
case common.ChannelTypeMoonshot:
case constant.ChannelTypeMoonshot:
return moonshot.ModelList
case common.ChannelTypeLingYiWanWu:
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ModelList
case common.ChannelTypeMiniMax:
case constant.ChannelTypeMiniMax:
return minimax.ModelList
case common.ChannelTypeXinference:
case constant.ChannelTypeXinference:
return xinference.ModelList
case common.ChannelTypeOpenRouter:
case constant.ChannelTypeOpenRouter:
return openrouter.ModelList
default:
return ModelList
@@ -473,17 +472,17 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string {
switch a.ChannelType {
case common.ChannelType360:
case constant.ChannelType360:
return ai360.ChannelName
case common.ChannelTypeMoonshot:
case constant.ChannelTypeMoonshot:
return moonshot.ChannelName
case common.ChannelTypeLingYiWanWu:
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ChannelName
case common.ChannelTypeMiniMax:
case constant.ChannelTypeMiniMax:
return minimax.ChannelName
case common.ChannelTypeXinference:
case constant.ChannelTypeXinference:
return xinference.ChannelName
case common.ChannelTypeOpenRouter:
case constant.ChannelTypeOpenRouter:
return openrouter.ChannelName
default:
return ChannelName

View File

@@ -168,7 +168,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
if info.ChannelType == common.ChannelTypeDeepSeek {
if info.ChannelType == constant.ChannelTypeDeepSeek {
if usage.PromptCacheHitTokens != 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}