refactor: Introduce pre-consume quota and unify relay handlers
This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic.

Key changes:

- **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests.
- **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels.
- **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package.
- **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -33,17 +35,6 @@ type ClaudeConvertInfo struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
const (
|
||||
RelayFormatOpenAI = "openai"
|
||||
RelayFormatClaude = "claude"
|
||||
RelayFormatGemini = "gemini"
|
||||
RelayFormatOpenAIResponses = "openai_responses"
|
||||
RelayFormatOpenAIAudio = "openai_audio"
|
||||
RelayFormatOpenAIImage = "openai_image"
|
||||
RelayFormatRerank = "rerank"
|
||||
RelayFormatEmbedding = "embedding"
|
||||
)
|
||||
|
||||
type RerankerInfo struct {
|
||||
Documents []any
|
||||
ReturnDocuments bool
|
||||
@@ -59,61 +50,103 @@ type ResponsesUsageInfo struct {
|
||||
BuiltInTools map[string]*BuildInToolInfo
|
||||
}
|
||||
|
||||
type RelayInfo struct {
|
||||
type ChannelMeta struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
ChannelIsMultiKey bool // 是否多密钥
|
||||
ChannelMultiKeyIndex int // 多密钥索引
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
UsingGroup string // 使用的分组
|
||||
UserGroup string // 用户所在分组
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
isFirstResponse bool
|
||||
ChannelIsMultiKey bool
|
||||
ChannelMultiKeyIndex int
|
||||
ChannelBaseUrl string
|
||||
ApiType int
|
||||
ApiVersion string
|
||||
ApiKey string
|
||||
Organization string
|
||||
ChannelCreateTime int64
|
||||
ParamOverride map[string]interface{}
|
||||
ChannelSetting dto.ChannelSettings
|
||||
ChannelOtherSettings dto.ChannelOtherSettings
|
||||
UpstreamModelName string
|
||||
IsModelMapped bool
|
||||
}
|
||||
|
||||
type RelayInfo struct {
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
UsingGroup string // 使用的分组
|
||||
UserGroup string // 用户所在分组
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
isFirstResponse bool
|
||||
//SendLastReasoningResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsGeminiBatchEmbedding bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
//RecodeModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
ApiKey string
|
||||
Organization string
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||
IsModelMapped bool
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
InputAudioFormat string
|
||||
OutputAudioFormat string
|
||||
RealtimeTools []dto.RealTimeTool
|
||||
IsFirstRequest bool
|
||||
AudioUsage bool
|
||||
ReasoningEffort string
|
||||
ChannelSetting dto.ChannelSettings
|
||||
ChannelOtherSettings dto.ChannelOtherSettings
|
||||
ParamOverride map[string]interface{}
|
||||
UserSetting dto.UserSetting
|
||||
UserEmail string
|
||||
UserQuota int
|
||||
RelayFormat string
|
||||
SendResponseCount int
|
||||
ChannelCreateTime int64
|
||||
RequestURLPath string
|
||||
PromptTokens int
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
InputAudioFormat string
|
||||
OutputAudioFormat string
|
||||
RealtimeTools []dto.RealTimeTool
|
||||
IsFirstRequest bool
|
||||
AudioUsage bool
|
||||
ReasoningEffort string
|
||||
UserSetting dto.UserSetting
|
||||
UserEmail string
|
||||
UserQuota int
|
||||
RelayFormat types.RelayFormat
|
||||
SendResponseCount int
|
||||
FinalPreConsumedQuota int // 最终预消耗的配额
|
||||
|
||||
PriceData types.PriceData
|
||||
|
||||
Request dto.Request
|
||||
|
||||
ThinkingContentInfo
|
||||
*ClaudeConvertInfo
|
||||
*RerankerInfo
|
||||
*ResponsesUsageInfo
|
||||
*ChannelMeta
|
||||
}
|
||||
|
||||
func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||
apiType, _ := common.ChannelType2APIType(channelType)
|
||||
channelMeta := &ChannelMeta{
|
||||
ChannelType: channelType,
|
||||
ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
|
||||
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
|
||||
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
|
||||
ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
||||
ParamOverride: paramOverride,
|
||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
IsModelMapped: false,
|
||||
}
|
||||
|
||||
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
|
||||
if ok {
|
||||
channelMeta.ChannelSetting = channelSetting
|
||||
}
|
||||
|
||||
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
|
||||
if ok {
|
||||
channelMeta.ChannelOtherSettings = channelOtherSettings
|
||||
}
|
||||
info.ChannelMeta = channelMeta
|
||||
}
|
||||
|
||||
// 定义支持流式选项的通道类型
|
||||
@@ -132,7 +165,8 @@ var streamSupportedChannels = map[int]bool{
|
||||
}
|
||||
|
||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info := genBaseRelayInfo(c, nil)
|
||||
info.RelayFormat = types.RelayFormatOpenAIRealtime
|
||||
info.ClientWs = ws
|
||||
info.InputAudioFormat = "pcm16"
|
||||
info.OutputAudioFormat = "pcm16"
|
||||
@@ -140,9 +174,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatClaude
|
||||
func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatClaude
|
||||
info.ShouldIncludeUsage = false
|
||||
info.ClaudeConvertInfo = &ClaudeConvertInfo{
|
||||
LastMessagesType: LastMessageTypeNone,
|
||||
@@ -150,41 +184,41 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayMode = relayconstant.RelayModeRerank
|
||||
info.RelayFormat = RelayFormatRerank
|
||||
info.RelayFormat = types.RelayFormatRerank
|
||||
info.RerankerInfo = &RerankerInfo{
|
||||
Documents: req.Documents,
|
||||
ReturnDocuments: req.GetReturnDocuments(),
|
||||
Documents: request.Documents,
|
||||
ReturnDocuments: request.GetReturnDocuments(),
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatOpenAIAudio
|
||||
func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatOpenAIAudio
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatEmbedding
|
||||
func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatEmbedding
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayMode = relayconstant.RelayModeResponses
|
||||
info.RelayFormat = RelayFormatOpenAIResponses
|
||||
info.RelayFormat = types.RelayFormatOpenAIResponses
|
||||
|
||||
info.SupportStreamOptions = false
|
||||
|
||||
info.ResponsesUsageInfo = &ResponsesUsageInfo{
|
||||
BuiltInTools: make(map[string]*BuildInToolInfo),
|
||||
}
|
||||
if len(req.Tools) > 0 {
|
||||
for _, tool := range req.Tools {
|
||||
if len(request.Tools) > 0 {
|
||||
for _, tool := range request.Tools {
|
||||
toolType := common.Interface2String(tool["type"])
|
||||
info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
|
||||
ToolName: toolType,
|
||||
@@ -200,104 +234,76 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
|
||||
}
|
||||
}
|
||||
}
|
||||
info.IsStream = req.Stream
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatGemini
|
||||
func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatGemini
|
||||
info.ShouldIncludeUsage = false
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatOpenAIImage
|
||||
func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatOpenAIImage
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||
func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatOpenAI
|
||||
return info
|
||||
}
|
||||
|
||||
func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
|
||||
//channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
//channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||
//paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||
|
||||
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
||||
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
||||
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
|
||||
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
|
||||
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
|
||||
apiType, _ := common.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
TokenKey: tokenKey,
|
||||
UserId: userId,
|
||||
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
|
||||
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
Request: request,
|
||||
|
||||
UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
|
||||
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
|
||||
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
|
||||
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
|
||||
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
|
||||
|
||||
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
|
||||
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
|
||||
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
|
||||
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
IsStream: request.IsStream(c),
|
||||
|
||||
StartTime: startTime,
|
||||
FirstResponseTime: startTime.Add(-time.Second),
|
||||
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
//RecodeModelName: c.GetString("original_model"),
|
||||
IsModelMapped: false,
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
|
||||
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
||||
ParamOverride: paramOverride,
|
||||
RelayFormat: RelayFormatOpenAI,
|
||||
ThinkingContentInfo: ThinkingContentInfo{
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
},
|
||||
|
||||
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
|
||||
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
|
||||
}
|
||||
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||
info.IsPlayground = true
|
||||
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
|
||||
info.RequestURLPath = "/v1" + info.RequestURLPath
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = constant.ChannelBaseURLs[channelType]
|
||||
}
|
||||
if info.ChannelType == constant.ChannelTypeAzure {
|
||||
info.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if info.ChannelType == constant.ChannelTypeVertexAi {
|
||||
info.ApiVersion = c.GetString("region")
|
||||
}
|
||||
if streamSupportedChannels[info.ChannelType] {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
|
||||
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
|
||||
if ok {
|
||||
info.ChannelSetting = channelSetting
|
||||
}
|
||||
|
||||
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
|
||||
if ok {
|
||||
info.ChannelOtherSettings = channelOtherSettings
|
||||
}
|
||||
|
||||
userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
|
||||
if ok {
|
||||
@@ -307,12 +313,39 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
info.PromptTokens = promptTokens
|
||||
func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
|
||||
switch relayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
return GenRelayInfoOpenAI(c, request), nil
|
||||
case types.RelayFormatOpenAIAudio:
|
||||
return GenRelayInfoOpenAIAudio(c, request), nil
|
||||
case types.RelayFormatOpenAIImage:
|
||||
return GenRelayInfoImage(c, request), nil
|
||||
case types.RelayFormatOpenAIRealtime:
|
||||
return GenRelayInfoWs(c, ws), nil
|
||||
case types.RelayFormatClaude:
|
||||
return GenRelayInfoClaude(c, request), nil
|
||||
case types.RelayFormatRerank:
|
||||
if request, ok := request.(*dto.RerankRequest); ok {
|
||||
return GenRelayInfoRerank(c, request), nil
|
||||
}
|
||||
return nil, errors.New("request is not a RerankRequest")
|
||||
case types.RelayFormatGemini:
|
||||
return GenRelayInfoGemini(c, request), nil
|
||||
case types.RelayFormatEmbedding:
|
||||
return GenRelayInfoEmbedding(c, request), nil
|
||||
case types.RelayFormatOpenAIResponses:
|
||||
if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||
return GenRelayInfoResponses(c, request), nil
|
||||
}
|
||||
return nil, errors.New("request is not a OpenAIResponsesRequest")
|
||||
default:
|
||||
return nil, errors.New("invalid relay format")
|
||||
}
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetIsStream(isStream bool) {
|
||||
info.IsStream = isStream
|
||||
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
info.PromptTokens = promptTokens
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetFirstResponseTime() {
|
||||
|
||||
Reference in New Issue
Block a user