117 lines
3.8 KiB
Go
117 lines
3.8 KiB
Go
package openai
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/dto"
|
|
"one-api/logger"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/relay/helper"
|
|
"one-api/service"
|
|
"one-api/types"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
|
defer service.CloseResponseBodyGracefully(resp)
|
|
|
|
// read response body
|
|
var responsesResponse dto.OpenAIResponsesResponse
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
|
}
|
|
err = common.Unmarshal(responseBody, &responsesResponse)
|
|
if err != nil {
|
|
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
|
}
|
|
if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
|
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
|
}
|
|
|
|
// 写入新的 response body
|
|
service.IOCopyBytesGracefully(c, resp, responseBody)
|
|
|
|
// compute usage
|
|
usage := dto.Usage{}
|
|
if responsesResponse.Usage != nil {
|
|
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
|
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
|
|
usage.TotalTokens = responsesResponse.Usage.TotalTokens
|
|
if responsesResponse.Usage.InputTokensDetails != nil {
|
|
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
|
}
|
|
}
|
|
// 解析 Tools 用量
|
|
for _, tool := range responsesResponse.Tools {
|
|
info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
|
|
}
|
|
return &usage, nil
|
|
}
|
|
|
|
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
|
if resp == nil || resp.Body == nil {
|
|
logger.LogError(c, "invalid response or response body")
|
|
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
|
|
}
|
|
|
|
defer service.CloseResponseBodyGracefully(resp)
|
|
|
|
var usage = &dto.Usage{}
|
|
var responseTextBuilder strings.Builder
|
|
|
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
|
// 检查当前数据是否包含 completed 状态和 usage 信息
|
|
var streamResponse dto.ResponsesStreamResponse
|
|
if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
|
|
sendResponsesStreamData(c, streamResponse, data)
|
|
switch streamResponse.Type {
|
|
case "response.completed":
|
|
if streamResponse.Response.Usage != nil {
|
|
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
|
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
|
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
|
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
|
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
|
}
|
|
}
|
|
case "response.output_text.delta":
|
|
// 处理输出文本
|
|
responseTextBuilder.WriteString(streamResponse.Delta)
|
|
case dto.ResponsesOutputTypeItemDone:
|
|
// 函数调用处理
|
|
if streamResponse.Item != nil {
|
|
switch streamResponse.Item.Type {
|
|
case dto.BuildInCallWebSearchCall:
|
|
info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
if usage.CompletionTokens == 0 {
|
|
// 计算输出文本的 token 数量
|
|
tempStr := responseTextBuilder.String()
|
|
if len(tempStr) > 0 {
|
|
// 非正常结束,使用输出文本的 token 数量
|
|
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
|
|
usage.CompletionTokens = completionTokens
|
|
}
|
|
}
|
|
|
|
if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
|
|
usage.PromptTokens = usage.CompletionTokens
|
|
} else {
|
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
}
|
|
|
|
return usage, nil
|
|
}
|