Merge branch 'main' into feat-04
This commit is contained in:
@@ -113,7 +113,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
||||
BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
@@ -454,6 +454,7 @@ type ClaudeResponseInfo struct {
|
||||
Model string
|
||||
ResponseText strings.Builder
|
||||
Usage *dto.Usage
|
||||
Done bool
|
||||
}
|
||||
|
||||
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
@@ -461,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||
claudeInfo.Model = claudeResponse.Message.Model
|
||||
|
||||
// message_start, 获取usage
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
if claudeResponse.Delta.Thinking != "" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
|
||||
// 判断是否完整
|
||||
claudeInfo.Done = true
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
} else {
|
||||
return false
|
||||
@@ -506,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
@@ -544,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||
if common.DebugEnabled {
|
||||
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 说明流模式建立失败,可能为官方出错
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
//
|
||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
|
||||
@@ -3,7 +3,6 @@ package cohere
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -78,7 +77,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
||||
}
|
||||
|
||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
usage := &dto.Usage{}
|
||||
responseText := ""
|
||||
|
||||
@@ -72,8 +72,11 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// suffix -thinking and -nothinking
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.OriginModelName, "-thinking-") {
|
||||
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||
info.UpstreamModelName = parts[0]
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -36,6 +37,47 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"video/flv": true,
|
||||
}
|
||||
|
||||
// Gemini 允许的思考预算范围
|
||||
const (
|
||||
pro25MinBudget = 128
|
||||
pro25MaxBudget = 32768
|
||||
flash25MaxBudget = 24576
|
||||
flash25LiteMinBudget = 512
|
||||
flash25LiteMaxBudget = 24576
|
||||
)
|
||||
|
||||
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
|
||||
func clampThinkingBudget(modelName string, budget int) int {
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
|
||||
if is25FlashLite {
|
||||
if budget < flash25LiteMinBudget {
|
||||
return flash25LiteMinBudget
|
||||
}
|
||||
if budget > flash25LiteMaxBudget {
|
||||
return flash25LiteMaxBudget
|
||||
}
|
||||
} else if isNew25Pro {
|
||||
if budget < pro25MinBudget {
|
||||
return pro25MinBudget
|
||||
}
|
||||
if budget > pro25MaxBudget {
|
||||
return pro25MaxBudget
|
||||
}
|
||||
} else { // 其他模型
|
||||
if budget < 0 {
|
||||
return 0
|
||||
}
|
||||
if budget > flash25MaxBudget {
|
||||
return flash25MaxBudget
|
||||
}
|
||||
}
|
||||
return budget
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
|
||||
@@ -57,16 +99,31 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
// 硬编码不支持 ThinkingBudget 的旧模型
|
||||
modelName := info.OriginModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
|
||||
if strings.Contains(modelName, "-thinking-") {
|
||||
parts := strings.SplitN(modelName, "-thinking-", 2)
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-thinking") {
|
||||
unsupportedModels := []string{
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
}
|
||||
|
||||
isUnsupported := false
|
||||
for _, unsupportedModel := range unsupportedModels {
|
||||
if strings.HasPrefix(info.OriginModelName, unsupportedModel) {
|
||||
if strings.HasPrefix(modelName, unsupportedModel) {
|
||||
isUnsupported = true
|
||||
break
|
||||
}
|
||||
@@ -78,39 +135,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
} else {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
|
||||
// 检查是否为新的2.5pro模型(支持ThinkingBudget但有特殊范围)
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if isNew25Pro {
|
||||
// 新的2.5pro模型:ThinkingBudget范围为128-32768
|
||||
if budgetTokens == 0 || budgetTokens < 128 {
|
||||
budgetTokens = 128
|
||||
} else if budgetTokens > 32768 {
|
||||
budgetTokens = 32768
|
||||
}
|
||||
} else {
|
||||
// 其他模型:ThinkingBudget范围为0-24576
|
||||
if budgetTokens == 0 || budgetTokens > 24576 {
|
||||
budgetTokens = 24576
|
||||
}
|
||||
}
|
||||
|
||||
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(int(budgetTokens)),
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
// 检查是否为新的2.5pro模型(不支持-nothinking,因为最低值只能为128)
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if !isNew25Pro {
|
||||
// 只有非新2.5pro模型才支持-nothinking
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro && !is25FlashLite {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
@@ -283,7 +315,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
|
||||
// 校验 MimeType 是否在 Gemini 支持的白名单中
|
||||
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
|
||||
return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
|
||||
url := part.GetImageMedia().Url
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
@@ -611,9 +644,9 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
@@ -754,7 +787,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
// responseText := ""
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
@@ -849,7 +882,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
|
||||
@@ -88,6 +88,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
|
||||
// 特殊处理 responses API
|
||||
if info.RelayMode == constant.RelayModeResponses {
|
||||
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
}
|
||||
|
||||
model_ := info.UpstreamModelName
|
||||
// 2025年5月10日后创建的渠道不移除.
|
||||
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
@@ -345,13 +346,14 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
||||
if err = c.ShouldBind(&reqBody); err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||
reqFp, err := reqBody.File.Open()
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
defer reqFp.Close()
|
||||
|
||||
tmpFp, err := os.CreateTemp("", "audio-*")
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
@@ -365,7 +367,7 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
|
||||
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package palm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -73,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
|
||||
|
||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
|
||||
@@ -98,7 +98,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
textRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
||||
BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
|
||||
@@ -61,6 +61,7 @@ type RelayInfo struct {
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
UserGroup string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
@@ -204,6 +205,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
TokenKey: tokenKey,
|
||||
UserId: userId,
|
||||
Group: group,
|
||||
UserGroup: c.GetString(constant.ContextKeyUserGroup),
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
StartTime: startTime,
|
||||
FirstResponseTime: startTime.Add(-time.Second),
|
||||
|
||||
@@ -2,14 +2,20 @@ package helper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GroupRatioInfo struct {
|
||||
GroupRatio float64
|
||||
GroupSpecialRatio float64
|
||||
}
|
||||
|
||||
type PriceData struct {
|
||||
ModelPrice float64
|
||||
ModelRatio float64
|
||||
@@ -17,18 +23,50 @@ type PriceData struct {
|
||||
CacheRatio float64
|
||||
CacheCreationRatio float64
|
||||
ImageRatio float64
|
||||
GroupRatio float64
|
||||
UsePrice bool
|
||||
ShouldPreConsumedQuota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
}
|
||||
|
||||
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
|
||||
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
|
||||
groupRatioInfo := GroupRatioInfo{
|
||||
GroupRatio: 1.0, // default ratio
|
||||
GroupSpecialRatio: 1.0, // default user group ratio
|
||||
}
|
||||
|
||||
// check auto group
|
||||
autoGroup, exists := ctx.Get("auto_group")
|
||||
if exists {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("final group: %s", autoGroup))
|
||||
}
|
||||
relayInfo.Group = autoGroup.(string)
|
||||
}
|
||||
|
||||
// check user group special ratio
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
// user group special ratio
|
||||
groupRatioInfo.GroupSpecialRatio = userGroupRatio
|
||||
groupRatioInfo.GroupRatio = userGroupRatio
|
||||
} else {
|
||||
// normal group ratio
|
||||
groupRatioInfo.GroupRatio = setting.GetGroupRatio(relayInfo.Group)
|
||||
}
|
||||
|
||||
return groupRatioInfo
|
||||
}
|
||||
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
||||
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(info.Group)
|
||||
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
var preConsumedQuota int
|
||||
var modelRatio float64
|
||||
var completionRatio float64
|
||||
@@ -58,17 +96,17 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatio
|
||||
ratio := modelRatio * groupRatioInfo.GroupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
}
|
||||
|
||||
priceData := PriceData{
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
CompletionRatio: completionRatio,
|
||||
GroupRatio: groupRatio,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
ImageRatio: imageRatio,
|
||||
|
||||
@@ -136,6 +136,20 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
// Clean up empty system instruction
|
||||
if req.SystemInstructions != nil {
|
||||
hasContent := false
|
||||
for _, part := range req.SystemInstructions.Parts {
|
||||
if part.Text != "" {
|
||||
hasContent = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasContent {
|
||||
req.SystemInstructions = nil
|
||||
}
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
|
||||
@@ -162,7 +162,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
// reset model price
|
||||
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
|
||||
@@ -90,15 +90,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
@@ -361,7 +362,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
cacheRatio := priceData.CacheRatio
|
||||
imageRatio := priceData.ImageRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
@@ -510,7 +511,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = imageRatio
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
)
|
||||
|
||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
//relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
//if constant.ShouldCheckPromptSensitive() {
|
||||
// err = checkRequestSensitive(textRequest, relayInfo)
|
||||
// if err != nil {
|
||||
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
// }
|
||||
//}
|
||||
|
||||
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
|
||||
//// count messages token error 计算promptTokens错误
|
||||
//if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
//}
|
||||
//
|
||||
if !getModelPriceSuccess {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
||||
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
||||
//}
|
||||
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
relayInfo.UsePrice = true
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
return openaiErr
|
||||
}
|
||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
|
||||
userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user