494 lines
16 KiB
Go
494 lines
16 KiB
Go
package common
|
|
|
|
import (
	"errors"
	"fmt"
	"sort"
	"strings"
	"time"

	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	relayconstant "one-api/relay/constant"
	"one-api/types"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
)
|
|
|
|
// ThinkingContentInfo holds the boolean flags that track thinking content
// for a relayed request. It is embedded by value in RelayInfo.
type ThinkingContentInfo struct {
	// IsFirstThinkingContent is initialised to true by genBaseRelayInfo.
	IsFirstThinkingContent bool
	// SendLastThinkingContent is initialised to false by genBaseRelayInfo.
	SendLastThinkingContent bool
	// HasSentThinkingContent is left at its zero value by the constructors
	// in this file; presumably set by adaptors once content was emitted.
	HasSentThinkingContent bool
}
|
|
|
|
// Message-type markers. GenRelayInfoClaude seeds
// ClaudeConvertInfo.LastMessagesType with LastMessageTypeNone; the other
// values are presumably assigned while a response is being converted.
const (
	LastMessageTypeNone     = "none"
	LastMessageTypeText     = "text"
	LastMessageTypeTools    = "tools"
	LastMessageTypeThinking = "thinking"
)
|
|
|
|
type ClaudeConvertInfo struct {
|
|
LastMessagesType string
|
|
Index int
|
|
Usage *dto.Usage
|
|
FinishReason string
|
|
Done bool
|
|
}
|
|
|
|
// RerankerInfo keeps the rerank-specific parts of a request. It is embedded
// by pointer in RelayInfo and populated by GenRelayInfoRerank.
type RerankerInfo struct {
	// Documents is copied from the rerank request's Documents field.
	Documents []any
	// ReturnDocuments is the result of the request's GetReturnDocuments().
	ReturnDocuments bool
}
|
|
|
|
// BuildInToolInfo describes one built-in tool declared on a Responses
// request, as recorded by GenRelayInfoResponses.
type BuildInToolInfo struct {
	// ToolName is the tool's "type" value from the request.
	ToolName string
	// CallCount starts at zero when the entry is created.
	CallCount int
	// SearchContextSize is only filled in for the web-search-preview tool,
	// where it defaults to "medium" when the request leaves it empty.
	SearchContextSize string
}
|
|
|
|
type ResponsesUsageInfo struct {
|
|
BuiltInTools map[string]*BuildInToolInfo
|
|
}
|
|
|
|
type ChannelMeta struct {
|
|
ChannelType int
|
|
ChannelId int
|
|
ChannelIsMultiKey bool
|
|
ChannelMultiKeyIndex int
|
|
ChannelBaseUrl string
|
|
ApiType int
|
|
ApiVersion string
|
|
ApiKey string
|
|
Organization string
|
|
ChannelCreateTime int64
|
|
ParamOverride map[string]interface{}
|
|
ChannelSetting dto.ChannelSettings
|
|
ChannelOtherSettings dto.ChannelOtherSettings
|
|
UpstreamModelName string
|
|
IsModelMapped bool
|
|
SupportStreamOptions bool // 是否支持流式选项
|
|
}
|
|
|
|
type RelayInfo struct {
|
|
TokenId int
|
|
TokenKey string
|
|
UserId int
|
|
UsingGroup string // 使用的分组
|
|
UserGroup string // 用户所在分组
|
|
TokenUnlimited bool
|
|
StartTime time.Time
|
|
FirstResponseTime time.Time
|
|
isFirstResponse bool
|
|
//SendLastReasoningResponse bool
|
|
IsStream bool
|
|
IsGeminiBatchEmbedding bool
|
|
IsPlayground bool
|
|
UsePrice bool
|
|
RelayMode int
|
|
OriginModelName string
|
|
RequestURLPath string
|
|
PromptTokens int
|
|
ShouldIncludeUsage bool
|
|
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
|
ClientWs *websocket.Conn
|
|
TargetWs *websocket.Conn
|
|
InputAudioFormat string
|
|
OutputAudioFormat string
|
|
RealtimeTools []dto.RealTimeTool
|
|
IsFirstRequest bool
|
|
AudioUsage bool
|
|
ReasoningEffort string
|
|
UserSetting dto.UserSetting
|
|
UserEmail string
|
|
UserQuota int
|
|
RelayFormat types.RelayFormat
|
|
SendResponseCount int
|
|
FinalPreConsumedQuota int // 最终预消耗的配额
|
|
|
|
PriceData types.PriceData
|
|
|
|
Request dto.Request
|
|
|
|
ThinkingContentInfo
|
|
*ClaudeConvertInfo
|
|
*RerankerInfo
|
|
*ResponsesUsageInfo
|
|
*ChannelMeta
|
|
}
|
|
|
|
func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
|
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
|
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
|
apiType, _ := common.ChannelType2APIType(channelType)
|
|
channelMeta := &ChannelMeta{
|
|
ChannelType: channelType,
|
|
ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
|
|
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
|
|
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
|
|
ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
|
|
ApiType: apiType,
|
|
ApiVersion: c.GetString("api_version"),
|
|
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
|
|
Organization: c.GetString("channel_organization"),
|
|
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
|
ParamOverride: paramOverride,
|
|
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
|
IsModelMapped: false,
|
|
SupportStreamOptions: false,
|
|
}
|
|
|
|
if channelType == constant.ChannelTypeAzure {
|
|
channelMeta.ApiVersion = GetAPIVersion(c)
|
|
}
|
|
if channelType == constant.ChannelTypeVertexAi {
|
|
channelMeta.ApiVersion = c.GetString("region")
|
|
}
|
|
|
|
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
|
|
if ok {
|
|
channelMeta.ChannelSetting = channelSetting
|
|
}
|
|
|
|
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
|
|
if ok {
|
|
channelMeta.ChannelOtherSettings = channelOtherSettings
|
|
}
|
|
|
|
if streamSupportedChannels[channelMeta.ChannelType] {
|
|
channelMeta.SupportStreamOptions = true
|
|
}
|
|
info.ChannelMeta = channelMeta
|
|
}
|
|
|
|
func (info *RelayInfo) ToString() string {
|
|
if info == nil {
|
|
return "RelayInfo<nil>"
|
|
}
|
|
|
|
// Basic info
|
|
b := &strings.Builder{}
|
|
fmt.Fprintf(b, "RelayInfo{ ")
|
|
fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat)
|
|
fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode)
|
|
fmt.Fprintf(b, "IsStream: %t, ", info.IsStream)
|
|
fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
|
|
fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
|
|
fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
|
|
fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens)
|
|
fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
|
|
fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
|
|
fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
|
|
fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota)
|
|
|
|
// User & token info (mask secrets)
|
|
fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ",
|
|
info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota)
|
|
fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited)
|
|
|
|
// Time info
|
|
latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds()
|
|
fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ",
|
|
info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs)
|
|
|
|
// Audio / realtime
|
|
if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage {
|
|
fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ",
|
|
info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools))
|
|
}
|
|
|
|
// Reasoning
|
|
if info.ReasoningEffort != "" {
|
|
fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort)
|
|
}
|
|
|
|
// Price data (non-sensitive)
|
|
if info.PriceData.UsePrice {
|
|
fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting())
|
|
}
|
|
|
|
// Channel metadata (mask ApiKey)
|
|
if info.ChannelMeta != nil {
|
|
cm := info.ChannelMeta
|
|
fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ",
|
|
cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions)
|
|
}
|
|
|
|
// Responses usage info (non-sensitive)
|
|
if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 {
|
|
fmt.Fprintf(b, "ResponsesTools{ ")
|
|
first := true
|
|
for name, tool := range info.ResponsesUsageInfo.BuiltInTools {
|
|
if !first {
|
|
fmt.Fprintf(b, ", ")
|
|
}
|
|
first = false
|
|
if tool != nil {
|
|
fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount)
|
|
} else {
|
|
fmt.Fprintf(b, "%s: calls=0", name)
|
|
}
|
|
}
|
|
fmt.Fprintf(b, " }, ")
|
|
}
|
|
|
|
fmt.Fprintf(b, "}")
|
|
return b.String()
|
|
}
|
|
|
|
// 定义支持流式选项的通道类型
|
|
var streamSupportedChannels = map[int]bool{
|
|
constant.ChannelTypeOpenAI: true,
|
|
constant.ChannelTypeAnthropic: true,
|
|
constant.ChannelTypeAws: true,
|
|
constant.ChannelTypeGemini: true,
|
|
constant.ChannelCloudflare: true,
|
|
constant.ChannelTypeAzure: true,
|
|
constant.ChannelTypeVolcEngine: true,
|
|
constant.ChannelTypeOllama: true,
|
|
constant.ChannelTypeXai: true,
|
|
constant.ChannelTypeDeepSeek: true,
|
|
constant.ChannelTypeBaiduV2: true,
|
|
}
|
|
|
|
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
|
info := genBaseRelayInfo(c, nil)
|
|
info.RelayFormat = types.RelayFormatOpenAIRealtime
|
|
info.ClientWs = ws
|
|
info.InputAudioFormat = "pcm16"
|
|
info.OutputAudioFormat = "pcm16"
|
|
info.IsFirstRequest = true
|
|
return info
|
|
}
|
|
|
|
func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
|
|
info := genBaseRelayInfo(c, request)
|
|
info.RelayFormat = types.RelayFormatClaude
|
|
info.ShouldIncludeUsage = false
|
|
info.ClaudeConvertInfo = &ClaudeConvertInfo{
|
|
LastMessagesType: LastMessageTypeNone,
|
|
}
|
|
return info
|
|
}
|
|
|
|
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
|
|
info := genBaseRelayInfo(c, request)
|
|
info.RelayMode = relayconstant.RelayModeRerank
|
|
info.RelayFormat = types.RelayFormatRerank
|
|
info.RerankerInfo = &RerankerInfo{
|
|
Documents: request.Documents,
|
|
ReturnDocuments: request.GetReturnDocuments(),
|
|
}
|
|
return info
|
|
}
|
|
|
|
func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
|
|
info := genBaseRelayInfo(c, request)
|
|
info.RelayFormat = types.RelayFormatOpenAIAudio
|
|
return info
|
|
}
|
|
|
|
func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
|
|
info := genBaseRelayInfo(c, request)
|
|
info.RelayFormat = types.RelayFormatEmbedding
|
|
return info
|
|
}
|
|
|
|
func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
|
|
info := genBaseRelayInfo(c, request)
|
|
info.RelayMode = relayconstant.RelayModeResponses
|
|
info.RelayFormat = types.RelayFormatOpenAIResponses
|
|
|
|
info.ResponsesUsageInfo = &ResponsesUsageInfo{
|
|
BuiltInTools: make(map[string]*BuildInToolInfo),
|
|
}
|
|
if len(request.Tools) > 0 {
|
|
for _, tool := range request.Tools {
|
|
toolType := common.Interface2String(tool["type"])
|
|
info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
|
|
ToolName: toolType,
|
|
CallCount: 0,
|
|
}
|
|
switch toolType {
|
|
case dto.BuildInToolWebSearchPreview:
|
|
searchContextSize := common.Interface2String(tool["search_context_size"])
|
|
if searchContextSize == "" {
|
|
searchContextSize = "medium"
|
|
}
|
|
info.ResponsesUsageInfo.BuiltInTools[toolType].SearchContextSize = searchContextSize
|
|
}
|
|
}
|
|
}
|
|
return info
|
|
}
|
|
|
|
func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
|
|
info := genBaseRelayInfo(c, request)
|
|
info.RelayFormat = types.RelayFormatGemini
|
|
info.ShouldIncludeUsage = false
|
|
|
|
return info
|
|
}
|
|
|
|
func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
|
|
info := genBaseRelayInfo(c, request)
|
|
info.RelayFormat = types.RelayFormatOpenAIImage
|
|
return info
|
|
}
|
|
|
|
func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
|
|
info := genBaseRelayInfo(c, request)
|
|
info.RelayFormat = types.RelayFormatOpenAI
|
|
return info
|
|
}
|
|
|
|
func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
|
|
|
//channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
|
//channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
|
//paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
|
|
|
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
|
|
if startTime.IsZero() {
|
|
startTime = time.Now()
|
|
}
|
|
|
|
isStream := false
|
|
|
|
if request != nil {
|
|
isStream = request.IsStream(c)
|
|
}
|
|
|
|
// firstResponseTime = time.Now() - 1 second
|
|
|
|
info := &RelayInfo{
|
|
Request: request,
|
|
|
|
UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
|
|
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
|
|
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
|
|
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
|
|
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
|
|
|
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
|
PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
|
|
|
|
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
|
|
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
|
|
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
|
|
|
|
isFirstResponse: true,
|
|
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
|
RequestURLPath: c.Request.URL.String(),
|
|
IsStream: isStream,
|
|
|
|
StartTime: startTime,
|
|
FirstResponseTime: startTime.Add(-time.Second),
|
|
ThinkingContentInfo: ThinkingContentInfo{
|
|
IsFirstThinkingContent: true,
|
|
SendLastThinkingContent: false,
|
|
},
|
|
}
|
|
|
|
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
|
info.IsPlayground = true
|
|
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
|
|
info.RequestURLPath = "/v1" + info.RequestURLPath
|
|
}
|
|
|
|
userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
|
|
if ok {
|
|
info.UserSetting = userSetting
|
|
}
|
|
|
|
return info
|
|
}
|
|
|
|
func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
|
|
switch relayFormat {
|
|
case types.RelayFormatOpenAI:
|
|
return GenRelayInfoOpenAI(c, request), nil
|
|
case types.RelayFormatOpenAIAudio:
|
|
return GenRelayInfoOpenAIAudio(c, request), nil
|
|
case types.RelayFormatOpenAIImage:
|
|
return GenRelayInfoImage(c, request), nil
|
|
case types.RelayFormatOpenAIRealtime:
|
|
return GenRelayInfoWs(c, ws), nil
|
|
case types.RelayFormatClaude:
|
|
return GenRelayInfoClaude(c, request), nil
|
|
case types.RelayFormatRerank:
|
|
if request, ok := request.(*dto.RerankRequest); ok {
|
|
return GenRelayInfoRerank(c, request), nil
|
|
}
|
|
return nil, errors.New("request is not a RerankRequest")
|
|
case types.RelayFormatGemini:
|
|
return GenRelayInfoGemini(c, request), nil
|
|
case types.RelayFormatEmbedding:
|
|
return GenRelayInfoEmbedding(c, request), nil
|
|
case types.RelayFormatOpenAIResponses:
|
|
if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
|
return GenRelayInfoResponses(c, request), nil
|
|
}
|
|
return nil, errors.New("request is not a OpenAIResponsesRequest")
|
|
case types.RelayFormatTask:
|
|
return genBaseRelayInfo(c, nil), nil
|
|
case types.RelayFormatMjProxy:
|
|
return genBaseRelayInfo(c, nil), nil
|
|
default:
|
|
return nil, errors.New("invalid relay format")
|
|
}
|
|
}
|
|
|
|
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
|
info.PromptTokens = promptTokens
|
|
}
|
|
|
|
func (info *RelayInfo) SetFirstResponseTime() {
|
|
if info.isFirstResponse {
|
|
info.FirstResponseTime = time.Now()
|
|
info.isFirstResponse = false
|
|
}
|
|
}
|
|
|
|
func (info *RelayInfo) HasSendResponse() bool {
|
|
return info.FirstResponseTime.After(info.StartTime)
|
|
}
|
|
|
|
type TaskRelayInfo struct {
|
|
*RelayInfo
|
|
Action string
|
|
OriginTaskID string
|
|
|
|
ConsumeQuota bool
|
|
}
|
|
|
|
func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) {
|
|
relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
info := &TaskRelayInfo{
|
|
RelayInfo: relayInfo,
|
|
}
|
|
return info, nil
|
|
}
|
|
|
|
// TaskSubmitReq is the JSON body of a task submission. Only "prompt" is
// always serialised; every other field is omitted when empty.
type TaskSubmitReq struct {
	Prompt   string         `json:"prompt"`
	Model    string         `json:"model,omitempty"`
	Mode     string         `json:"mode,omitempty"`
	Image    string         `json:"image,omitempty"`
	Size     string         `json:"size,omitempty"`
	Duration int            `json:"duration,omitempty"`
	Metadata map[string]any `json:"metadata,omitempty"`
}
|
|
|
|
// TaskInfo is the JSON representation of a task's state. "code", "task_id"
// and "status" are always serialised; the remaining fields are omitted when
// empty.
type TaskInfo struct {
	Code     int    `json:"code"`
	TaskID   string `json:"task_id"`
	Status   string `json:"status"`
	Reason   string `json:"reason,omitempty"`
	Url      string `json:"url,omitempty"`
	Progress string `json:"progress,omitempty"`
}
|