This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability. - Minor adjustments to improve code clarity and organization throughout various modules. This change aims to streamline logging and improve the overall architecture of the codebase.
249 lines
7.0 KiB
Go
249 lines
7.0 KiB
Go
package cohere
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/dto"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/relay/helper"
|
|
"one-api/service"
|
|
"one-api/types"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
|
cohereReq := CohereRequest{
|
|
Model: textRequest.Model,
|
|
ChatHistory: []ChatHistory{},
|
|
Message: "",
|
|
Stream: textRequest.Stream,
|
|
MaxTokens: textRequest.GetMaxTokens(),
|
|
}
|
|
if common.CohereSafetySetting != "NONE" {
|
|
cohereReq.SafetyMode = common.CohereSafetySetting
|
|
}
|
|
if cohereReq.MaxTokens == 0 {
|
|
cohereReq.MaxTokens = 4000
|
|
}
|
|
for _, msg := range textRequest.Messages {
|
|
if msg.Role == "user" {
|
|
cohereReq.Message = msg.StringContent()
|
|
} else {
|
|
var role string
|
|
if msg.Role == "assistant" {
|
|
role = "CHATBOT"
|
|
} else if msg.Role == "system" {
|
|
role = "SYSTEM"
|
|
} else {
|
|
role = "USER"
|
|
}
|
|
cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
|
|
Role: role,
|
|
Message: msg.StringContent(),
|
|
})
|
|
}
|
|
}
|
|
|
|
return &cohereReq
|
|
}
|
|
|
|
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
|
|
if rerankRequest.TopN == 0 {
|
|
rerankRequest.TopN = 1
|
|
}
|
|
cohereReq := CohereRerankRequest{
|
|
Query: rerankRequest.Query,
|
|
Documents: rerankRequest.Documents,
|
|
Model: rerankRequest.Model,
|
|
TopN: rerankRequest.TopN,
|
|
ReturnDocuments: true,
|
|
}
|
|
return &cohereReq
|
|
}
|
|
|
|
// stopReasonCohere2OpenAI translates a Cohere finish reason into the
// corresponding OpenAI finish_reason value.
//
// COMPLETE and STOP_SEQUENCE map to "stop" and MAX_TOKENS maps to "length",
// which is the value OpenAI clients expect when generation is cut off by the
// token limit. Any other reason is returned unchanged.
func stopReasonCohere2OpenAI(reason string) string {
	switch reason {
	case "COMPLETE", "STOP_SEQUENCE":
		return "stop"
	case "MAX_TOKENS":
		return "length"
	default:
		return reason
	}
}
|
|
|
|
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
|
responseId := helper.GetResponseID(c)
|
|
createdTime := common.GetTimestamp()
|
|
usage := &dto.Usage{}
|
|
responseText := ""
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
if atEOF && len(data) == 0 {
|
|
return 0, nil, nil
|
|
}
|
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
return i + 1, data[0:i], nil
|
|
}
|
|
if atEOF {
|
|
return len(data), data, nil
|
|
}
|
|
return 0, nil, nil
|
|
})
|
|
dataChan := make(chan string)
|
|
stopChan := make(chan bool)
|
|
go func() {
|
|
for scanner.Scan() {
|
|
data := scanner.Text()
|
|
dataChan <- data
|
|
}
|
|
stopChan <- true
|
|
}()
|
|
helper.SetEventStreamHeaders(c)
|
|
isFirst := true
|
|
c.Stream(func(w io.Writer) bool {
|
|
select {
|
|
case data := <-dataChan:
|
|
if isFirst {
|
|
isFirst = false
|
|
info.FirstResponseTime = time.Now()
|
|
}
|
|
data = strings.TrimSuffix(data, "\r")
|
|
var cohereResp CohereResponse
|
|
err := json.Unmarshal([]byte(data), &cohereResp)
|
|
if err != nil {
|
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
var openaiResp dto.ChatCompletionsStreamResponse
|
|
openaiResp.Id = responseId
|
|
openaiResp.Created = createdTime
|
|
openaiResp.Object = "chat.completion.chunk"
|
|
openaiResp.Model = info.UpstreamModelName
|
|
if cohereResp.IsFinished {
|
|
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
|
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
|
{
|
|
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
|
|
Index: 0,
|
|
FinishReason: &finishReason,
|
|
},
|
|
}
|
|
if cohereResp.Response != nil {
|
|
usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
|
|
usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
|
|
}
|
|
} else {
|
|
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
|
{
|
|
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
|
Role: "assistant",
|
|
Content: &cohereResp.Text,
|
|
},
|
|
Index: 0,
|
|
},
|
|
}
|
|
responseText += cohereResp.Text
|
|
}
|
|
jsonStr, err := json.Marshal(openaiResp)
|
|
if err != nil {
|
|
common.SysLog("error marshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
|
return true
|
|
case <-stopChan:
|
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
return false
|
|
}
|
|
})
|
|
if usage.PromptTokens == 0 {
|
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
}
|
|
return usage, nil
|
|
}
|
|
|
|
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
|
createdTime := common.GetTimestamp()
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
|
}
|
|
service.CloseResponseBodyGracefully(resp)
|
|
var cohereResp CohereResponseResult
|
|
err = json.Unmarshal(responseBody, &cohereResp)
|
|
if err != nil {
|
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
|
}
|
|
usage := dto.Usage{}
|
|
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
|
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
|
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
|
|
|
|
var openaiResp dto.TextResponse
|
|
openaiResp.Id = cohereResp.ResponseId
|
|
openaiResp.Created = createdTime
|
|
openaiResp.Object = "chat.completion"
|
|
openaiResp.Model = info.UpstreamModelName
|
|
openaiResp.Usage = usage
|
|
|
|
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
|
{
|
|
Index: 0,
|
|
Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
|
|
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
|
|
},
|
|
}
|
|
|
|
jsonResponse, err := json.Marshal(openaiResp)
|
|
if err != nil {
|
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
|
}
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, _ = c.Writer.Write(jsonResponse)
|
|
return &usage, nil
|
|
}
|
|
|
|
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
|
}
|
|
service.CloseResponseBodyGracefully(resp)
|
|
var cohereResp CohereRerankResponseResult
|
|
err = json.Unmarshal(responseBody, &cohereResp)
|
|
if err != nil {
|
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
|
}
|
|
usage := dto.Usage{}
|
|
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
|
usage.PromptTokens = info.PromptTokens
|
|
usage.CompletionTokens = 0
|
|
usage.TotalTokens = info.PromptTokens
|
|
} else {
|
|
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
|
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
|
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
|
|
}
|
|
|
|
var rerankResp dto.RerankResponse
|
|
rerankResp.Results = cohereResp.Results
|
|
rerankResp.Usage = usage
|
|
|
|
jsonResponse, err := json.Marshal(rerankResp)
|
|
if err != nil {
|
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
|
}
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = c.Writer.Write(jsonResponse)
|
|
return &usage, nil
|
|
}
|