diff --git a/model/ability.go b/model/ability.go index f36ff764..6dd8d8a6 100644 --- a/model/ability.go +++ b/model/ability.go @@ -136,7 +136,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, } } } else { - return nil, errors.New("channel not found") + return nil, nil } err = DB.First(&channel, "id = ?", channel.Id).Error return &channel, err diff --git a/model/channel_cache.go b/model/channel_cache.go index d18e9c89..45069ba0 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -130,7 +130,7 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, channels := group2model2channels[group][model] if len(channels) == 0 { - return nil, errors.New("channel not found") + return nil, nil } if len(channels) == 1 { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 71eb9ba4..2e31ec55 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/setting/model_setting" @@ -21,10 +22,13 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req) + if err != nil { + return nil, err + } + return a.ConvertOpenAIRequest(c, info, oaiReq.(*dto.GeneralOpenAIRequest)) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index d19ee1ae..7e57bdac 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -736,7 +737,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C choice := dto.ChatCompletionsStreamResponseChoice{ Index: int(candidate.Index), Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ - Role: "assistant", + //Role: "assistant", }, } var texts []string @@ -798,6 +799,27 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C return &response, isStop, hasImage } +func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error { + streamData, err := common.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal stream response: %w", err) + } + err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) + if err != nil { + return fmt.Errorf("failed to handle stream format: %w", err) + } + return nil +} + +func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error { + streamData, err := common.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal stream response: %w", err) + } + openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, info.ShouldIncludeUsage) + return nil +} + func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { // responseText := "" id := helper.GetResponseID(c) @@ -805,6 +827,8 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * var usage = &dto.Usage{} var imageCount int + respCount := 0 + helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) @@ -833,18 +857,31 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * } } } - err = helper.ObjectData(c, response) + + if respCount == 0 { + // send first response + err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)) + if err != nil { + common.LogError(c, err.Error()) + } + } + + err = handleStream(c, info, response) if err != nil { common.LogError(c, err.Error()) } if isStop { - response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) - helper.ObjectData(c, response) + _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)) } + respCount++ return true }) - var response *dto.ChatCompletionsStreamResponse + if respCount == 0 { + // 空补全,报错不计费 + // empty response, throw an error + return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError) + } if imageCount != 0 { if usage.CompletionTokens == 0 { @@ -855,14 +892,14 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * usage.PromptTokensDetails.TextTokens = usage.PromptTokens usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens - if info.ShouldIncludeUsage { - response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) - err := helper.ObjectData(c, response) - if err != nil { - common.SysError("send final response failed: " + err.Error()) - } + response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) + err := handleFinalStream(c, info, response) + if err != nil { + common.SysError("send final response failed: " + err.Error()) } - helper.Done(c) + //if info.RelayFormat == relaycommon.RelayFormatOpenAI { + // helper.Done(c) + //} //resp.Body.Close() return usage, nil } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 7fee505a..1681c9ff 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -14,7 +14,7 @@ import ( ) // 辅助函数 -func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { +func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { info.SendResponseCount++ switch info.RelayFormat { case relaycommon.RelayFormatOpenAI: @@ -158,7 +158,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int return nil } -func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string, +func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string, responseId string, createAt int64, model string, systemFingerprint string, usage *dto.Usage, containStreamUsage bool) { diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index d739ea19..82bd2d26 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -123,24 +123,11 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var toolCount int var usage = &dto.Usage{} var streamItems []string // store stream items - var forceFormat bool - var thinkToContent bool - - if info.ChannelSetting.ForceFormat { - forceFormat = true - } - - if info.ChannelSetting.ThinkingToContent { - thinkToContent = true - } - - var ( - lastStreamData string - ) + var lastStreamData string helper.StreamScannerHandler(c, resp, info, func(data string) bool { if lastStreamData != "" { - err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent) + err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) if err != nil { common.SysError("error handling stream format: " + err.Error()) } @@ -161,7 +148,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re if info.RelayFormat == relaycommon.RelayFormatOpenAI { if shouldSendLastResp { - _ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) + _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) } } @@ -180,7 +167,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } } } - handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage) + HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage) return usage, nil } diff --git a/relay/helper/common.go b/relay/helper/common.go index 5d23b512..c8edb798 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -139,6 +139,24 @@ func GetLocalRealtimeID(c *gin.Context) string { return fmt.Sprintf("evt_%s", logID) } +func GenerateStartEmptyResponse(id string, createAt int64, model string, systemFingerprint *string) *dto.ChatCompletionsStreamResponse { + return &dto.ChatCompletionsStreamResponse{ + Id: id, + Object: "chat.completion.chunk", + Created: createAt, + Model: model, + SystemFingerprint: systemFingerprint, + Choices: []dto.ChatCompletionsStreamResponseChoice{ + { + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ + Role: "assistant", + Content: common.GetPointer(""), + }, + }, + }, + } +} + func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse { return &dto.ChatCompletionsStreamResponse{ Id: id, diff --git a/types/error.go b/types/error.go index fa07c231..c94bd001 100644 --- a/types/error.go +++ b/types/error.go @@ -63,6 +63,7 @@ const ( ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code" ErrorCodeBadResponse ErrorCode = "bad_response" ErrorCodeBadResponseBody ErrorCode = "bad_response_body" + ErrorCodeEmptyResponse ErrorCode = "empty_response" // sql error ErrorCodeQueryDataError ErrorCode = "query_data_error"