This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized approach based on the `common` package. Key changes: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource-initialization error handling to use the new logging structure, improving maintainability and readability. - Made minor adjustments to improve code clarity and organization across various modules. The goal of this change is to streamline logging and improve the overall architecture of the codebase.
268 lines
8.8 KiB
Go
268 lines
8.8 KiB
Go
package openai
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"github.com/samber/lo"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/dto"
|
||
"one-api/logger"
|
||
relaycommon "one-api/relay/common"
|
||
relayconstant "one-api/relay/constant"
|
||
"one-api/relay/helper"
|
||
"one-api/service"
|
||
"one-api/types"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// 辅助函数
|
||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||
info.SendResponseCount++
|
||
|
||
switch info.RelayFormat {
|
||
case types.RelayFormatOpenAI:
|
||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||
case types.RelayFormatClaude:
|
||
return handleClaudeFormat(c, data, info)
|
||
case types.RelayFormatGemini:
|
||
return handleGeminiFormat(c, data, info)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||
var streamResponse dto.ChatCompletionsStreamResponse
|
||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||
return err
|
||
}
|
||
|
||
if streamResponse.Usage != nil {
|
||
info.ClaudeConvertInfo.Usage = streamResponse.Usage
|
||
}
|
||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||
for _, resp := range claudeResponses {
|
||
helper.ClaudeData(c, *resp)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||
var streamResponse dto.ChatCompletionsStreamResponse
|
||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||
return err
|
||
}
|
||
|
||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||
|
||
// 如果返回 nil,表示没有实际内容,跳过发送
|
||
if geminiResponse == nil {
|
||
return nil
|
||
}
|
||
|
||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||
if err != nil {
|
||
logger.LogError(c, "failed to marshal gemini response: "+err.Error())
|
||
return err
|
||
}
|
||
|
||
// send gemini format response
|
||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||
flusher.Flush()
|
||
} else {
|
||
return errors.New("streaming error: flusher not found")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||
for _, choice := range streamResponse.Choices {
|
||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||
if choice.Delta.ToolCalls != nil {
|
||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||
*toolCount = len(choice.Delta.ToolCalls)
|
||
}
|
||
for _, tool := range choice.Delta.ToolCalls {
|
||
responseTextBuilder.WriteString(tool.Function.Name)
|
||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||
|
||
switch relayMode {
|
||
case relayconstant.RelayModeChatCompletions:
|
||
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
||
case relayconstant.RelayModeCompletions:
|
||
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||
// 一次性解析失败,逐个解析
|
||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||
for _, item := range streamItems {
|
||
var streamResponse dto.ChatCompletionsStreamResponse
|
||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||
return err
|
||
}
|
||
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
||
common.SysLog("error processing stream response: " + err.Error())
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 批量处理所有响应
|
||
for _, streamResponse := range streamResponses {
|
||
for _, choice := range streamResponse.Choices {
|
||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||
if choice.Delta.ToolCalls != nil {
|
||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||
*toolCount = len(choice.Delta.ToolCalls)
|
||
}
|
||
for _, tool := range choice.Delta.ToolCalls {
|
||
responseTextBuilder.WriteString(tool.Function.Name)
|
||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
||
var streamResponses []dto.CompletionsStreamResponse
|
||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||
// 一次性解析失败,逐个解析
|
||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||
for _, item := range streamItems {
|
||
var streamResponse dto.CompletionsStreamResponse
|
||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||
continue
|
||
}
|
||
for _, choice := range streamResponse.Choices {
|
||
responseTextBuilder.WriteString(choice.Text)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 批量处理所有响应
|
||
for _, streamResponse := range streamResponses {
|
||
for _, choice := range streamResponse.Choices {
|
||
responseTextBuilder.WriteString(choice.Text)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
||
systemFingerprint *string, model *string, usage **dto.Usage,
|
||
containStreamUsage *bool, info *relaycommon.RelayInfo,
|
||
shouldSendLastResp *bool) error {
|
||
|
||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
||
return err
|
||
}
|
||
|
||
*responseId = lastStreamResponse.Id
|
||
*createAt = lastStreamResponse.Created
|
||
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
||
*model = lastStreamResponse.Model
|
||
|
||
if service.ValidUsage(lastStreamResponse.Usage) {
|
||
*containStreamUsage = true
|
||
*usage = lastStreamResponse.Usage
|
||
if !info.ShouldIncludeUsage {
|
||
*shouldSendLastResp = lo.SomeBy(lastStreamResponse.Choices, func(choice dto.ChatCompletionsStreamResponseChoice) bool {
|
||
return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != ""
|
||
})
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// HandleFinalResponse emits the closing event(s) of a relayed stream in the
// client's requested relay format (OpenAI, Claude, or Gemini).
//
// lastStreamData is the raw final chunk received from the upstream; usage and
// containStreamUsage describe the accumulated token usage and whether the
// upstream already reported usage inside the stream. Errors are logged and
// swallowed: the stream is ending either way.
func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
	responseId string, createAt int64, model string, systemFingerprint string,
	usage *dto.Usage, containStreamUsage bool) {

	switch info.RelayFormat {
	case types.RelayFormatOpenAI:
		// Client asked for a usage chunk but the upstream never sent one:
		// synthesize a final usage-only response before terminating.
		if info.ShouldIncludeUsage && !containStreamUsage {
			response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
			response.SetSystemFingerprint(systemFingerprint)
			helper.ObjectData(c, response)
		}
		helper.Done(c)

	case types.RelayFormatClaude:
		// Mark the conversion as finished so the Claude converter emits its
		// terminal events for this chunk.
		info.ClaudeConvertInfo.Done = true
		var streamResponse dto.ChatCompletionsStreamResponse
		if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
			common.SysLog("error unmarshalling stream response: " + err.Error())
			return
		}

		info.ClaudeConvertInfo.Usage = usage

		claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
		for _, resp := range claudeResponses {
			// Best-effort write; a failed final event is not recoverable here.
			_ = helper.ClaudeData(c, *resp)
		}

	case types.RelayFormatGemini:
		var streamResponse dto.ChatCompletionsStreamResponse
		if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
			common.SysLog("error unmarshalling stream response: " + err.Error())
			return
		}

		// This handles OpenAI's last stream response, whose delta is empty
		// but carries a finish_reason field. Compared with Google's official
		// stream responses, the output converted from OpenAI therefore has
		// one extra response with empty parts and finishReason STOP, while
		// the response containing the final text output (the second to last)
		// has finishReason null. It is not yet known whether any client is
		// incompatible with this.

		geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)

		// Empty data at the start of an OpenAI stream converts to nil; skip.
		if geminiResponse == nil {
			return
		}

		geminiResponseStr, err := common.Marshal(geminiResponse)
		if err != nil {
			common.SysLog("error marshalling gemini response: " + err.Error())
			return
		}

		// Send the final Gemini-format response and flush it to the client.
		c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
		if flusher, ok := c.Writer.(http.Flusher); ok {
			flusher.Flush()
		}
	}
}
|
||
|
||
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
|
||
if data == "" {
|
||
return
|
||
}
|
||
helper.ResponseChunkData(c, streamResponse, data)
|
||
}
|