diff --git a/common/custom-event.go b/common/custom-event.go index 69da4bc4..d8f9ec9f 100644 --- a/common/custom-event.go +++ b/common/custom-event.go @@ -44,7 +44,7 @@ var fieldReplacer = strings.NewReplacer( "\r", "\\r") var dataReplacer = strings.NewReplacer( - "\n", "\ndata:", + "\n", "\n", "\r", "\\r") type CustomEvent struct { diff --git a/controller/relay.go b/controller/relay.go index 460599b5..fb4c524f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -148,6 +148,50 @@ func WssRelay(c *gin.Context) { } } +func RelayClaude(c *gin.Context) { + //relayMode := constant.Path2RelayMode(c.Request.URL.Path) + requestId := c.GetString(common.RequestIdKey) + group := c.GetString("group") + originalModel := c.GetString("original_model") + var claudeErr *dto.ClaudeErrorWithStatusCode + + for i := 0; i <= common.RetryTimes; i++ { + channel, err := getChannel(c, group, originalModel, i) + if err != nil { + common.LogError(c, err.Error()) + claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) + break + } + + claudeErr = claudeRequest(c, channel) + + if claudeErr == nil { + return // 成功处理请求,直接返回 + } + + openaiErr := service.ClaudeErrorToOpenAIError(claudeErr) + + go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) + + if !shouldRetry(c, openaiErr, common.RetryTimes-i) { + break + } + } + useChannel := c.GetStringSlice("use_channel") + if len(useChannel) > 1 { + retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) + common.LogInfo(c, retryLogStr) + } + + if claudeErr != nil { + claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId) + c.JSON(claudeErr.StatusCode, gin.H{ + "type": "error", + "error": claudeErr.Error, + }) + } +} + func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) @@ -162,6 +206,13 @@ func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *mode return relay.WssHelper(c, ws) } +func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode { + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return relay.ClaudeHelper(c) +} + func addUsedChannel(c *gin.Context, channelId int) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) diff --git a/dto/claude.go b/dto/claude.go new file mode 100644 index 00000000..60f638f6 --- /dev/null +++ b/dto/claude.go @@ -0,0 +1,186 @@ +package dto + +import "encoding/json" + +type ClaudeMetadata struct { + UserId string `json:"user_id"` +} + +type ClaudeMediaMessage struct { + Type string `json:"type"` + Text *string `json:"text,omitempty"` + Model string `json:"model,omitempty"` + Source *ClaudeMessageSource `json:"source,omitempty"` + Usage *ClaudeUsage `json:"usage,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + PartialJson string `json:"partial_json,omitempty"` + Role string `json:"role,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` + Delta string `json:"delta,omitempty"` + // tool_calls + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + ToolUseId string `json:"tool_use_id,omitempty"` +} + +func (c *ClaudeMediaMessage) SetText(s string) { + c.Text = &s +} + +func (c *ClaudeMediaMessage) GetText() string { + if c.Text == nil { + return "" + } + return *c.Text +} + +type ClaudeMessageSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data any `json:"data"` +} + +type ClaudeMessage struct { + Role string `json:"role"` + Content any `json:"content"` +} + +func (c *ClaudeMessage) IsStringContent() bool { + _, ok := c.Content.(string) + return ok +} + +func (c *ClaudeMessage) GetStringContent() string { + if c.IsStringContent() { + return c.Content.(string) + } + return "" +} + +func (c *ClaudeMessage) SetStringContent(content string) { + c.Content = content +} + +func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) { + // map content to []ClaudeMediaMessage + // parse to json + jsonContent, _ := json.Marshal(c.Content) + var contentList []ClaudeMediaMessage + err := json.Unmarshal(jsonContent, &contentList) + if err != nil { + return make([]ClaudeMediaMessage, 0), err + } + return contentList, nil +} + +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]interface{} `json:"input_schema"` +} + +type InputSchema struct { + Type string `json:"type"` + Properties any `json:"properties,omitempty"` + Required any `json:"required,omitempty"` +} + +type ClaudeRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + System any `json:"system,omitempty"` + Messages []ClaudeMessage `json:"messages,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + //ClaudeMetadata `json:"metadata,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Thinking *Thinking `json:"thinking,omitempty"` +} + +type Thinking struct { + Type string `json:"type"` + BudgetTokens int `json:"budget_tokens"` +} + +func (c *ClaudeRequest) IsStringSystem() bool { + _, ok := c.System.(string) + return ok +} + +func (c *ClaudeRequest) GetStringSystem() string { + if c.IsStringSystem() { + return c.System.(string) + } + return "" +} + +func (c *ClaudeRequest) SetStringSystem(system string) { + c.System = system +} + +func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage { + // map content to []ClaudeMediaMessage + // parse to json + jsonContent, _ := json.Marshal(c.System) + var contentList []ClaudeMediaMessage + if err := json.Unmarshal(jsonContent, &contentList); err == nil { + return contentList + } + return make([]ClaudeMediaMessage, 0) +} + +type ClaudeError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type ClaudeErrorWithStatusCode struct { + Error ClaudeError `json:"error"` + StatusCode int `json:"status_code"` + LocalError bool +} + +type ClaudeResponse struct { + Id string `json:"id,omitempty"` + Type string `json:"type"` + Role string `json:"role,omitempty"` + Content []ClaudeMediaMessage `json:"content,omitempty"` + Completion string `json:"completion,omitempty"` + StopReason string `json:"stop_reason,omitempty"` + Model string `json:"model,omitempty"` + Error *ClaudeError `json:"error,omitempty"` + Usage *ClaudeUsage `json:"usage,omitempty"` + Index *int `json:"index,omitempty"` + ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"` + Delta *ClaudeMediaMessage `json:"delta,omitempty"` + Message *ClaudeMediaMessage `json:"message,omitempty"` +} + +// set index +func (c *ClaudeResponse) SetIndex(i int) { + c.Index = &i +} + +// get index +func (c *ClaudeResponse) GetIndex() int { + if c.Index == nil { + return 0 + } + return *c.Index +} + +type ClaudeUsage struct { + InputTokens int `json:"input_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + OutputTokens int `json:"output_tokens"` +} diff --git a/dto/openai_response.go b/dto/openai_response.go index 9188fad7..4097db55 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -125,6 +125,20 @@ type ChatCompletionsStreamResponse struct { Usage *Usage `json:"usage"` } +func (c *ChatCompletionsStreamResponse) IsToolCall() bool { + if len(c.Choices) == 0 { + return false + } + return len(c.Choices[0].Delta.ToolCalls) > 0 +} + +func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse { + if c.IsToolCall() { + return &c.Choices[0].Delta.ToolCalls[0] + } + return nil +} + func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse { choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices)) copy(choices, c.Choices) diff --git a/dto/realtime.go b/dto/realtime.go index e28d813e..8c6e8932 100644 --- a/dto/realtime.go +++ b/dto/realtime.go @@ -44,10 +44,11 @@ type RealtimeUsage struct { } type InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - TextTokens int `json:"text_tokens"` - AudioTokens int `json:"audio_tokens"` - ImageTokens int `json:"image_tokens"` + CachedTokens int `json:"cached_tokens"` + CachedCreationTokens int + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ImageTokens int `json:"image_tokens"` } type OutputTokenDetails struct { diff --git a/go.mod b/go.mod index ca526466..ce768bf3 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b + github.com/bytedance/sonic v1.11.6 github.com/gin-contrib/cors v1.7.2 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -42,7 +43,6 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect - github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect diff --git a/middleware/auth.go b/middleware/auth.go index a589f52c..fece4553 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -174,6 +174,14 @@ func TokenAuth() func(c *gin.Context) { } c.Request.Header.Set("Authorization", "Bearer "+key) } + // 检查path包含/v1/messages + if strings.Contains(c.Request.URL.Path, "/v1/messages") { + // 从x-api-key中获取key + key := c.Request.Header.Get("x-api-key") + if key != "" { + c.Request.Header.Set("Authorization", "Bearer "+key) + } + } key := c.Request.Header.Get("Authorization") parts := make([]string, 0) key = strings.TrimPrefix(key, "Bearer ") diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index c970fd48..9f449b54 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -22,6 +22,7 @@ type Adaptor interface { DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string GetChannelName() string + ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) } type TaskAdaptor interface { diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 32be399b..9d3ee99f 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 7f2a2841..e735ee2b 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -20,6 +20,10 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return request, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -48,7 +52,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return nil, errors.New("request is nil") } - var claudeReq *claude.ClaudeRequest + var claudeReq *dto.ClaudeRequest var err error claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) if err != nil { diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 3b615134..0188c30a 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -1,25 +1,25 @@ package aws import ( - "one-api/relay/channel/claude" + "one-api/dto" ) type AwsClaudeRequest struct { // AnthropicVersion should be "bedrock-2023-05-31" - AnthropicVersion string `json:"anthropic_version"` - System string `json:"system,omitempty"` - Messages []claude.ClaudeMessage `json:"messages"` - MaxTokens uint `json:"max_tokens,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - Thinking *claude.Thinking `json:"thinking,omitempty"` + AnthropicVersion string `json:"anthropic_version"` + System any `json:"system,omitempty"` + Messages []dto.ClaudeMessage `json:"messages"` + MaxTokens uint `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Thinking *dto.Thinking `json:"thinking,omitempty"` } -func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest { +func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest { return &AwsClaudeRequest{ AnthropicVersion: "bedrock-2023-05-31", System: req.System, diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index e1270606..0d517256 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -9,7 +9,7 @@ import ( "io" "net/http" "one-api/common" - relaymodel "one-api/dto" + "one-api/dto" "one-api/relay/channel/claude" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -39,10 +39,10 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime. return client, nil } -func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode { - return &relaymodel.OpenAIErrorWithStatusCode{ +func wrapErr(err error) *dto.OpenAIErrorWithStatusCode { + return &dto.OpenAIErrorWithStatusCode{ StatusCode: http.StatusInternalServerError, - Error: relaymodel.OpenAIError{ + Error: dto.OpenAIError{ Message: fmt.Sprintf("%s", err.Error()), }, } @@ -56,7 +56,7 @@ func awsModelID(requestModel string) (string, error) { return requestModel, nil } -func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { +func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { awsCli, err := newAwsClient(c, info) if err != nil { return wrapErr(errors.Wrap(err, "newAwsClient")), nil @@ -77,7 +77,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* if !ok { return wrapErr(errors.New("request not found")), nil } - claudeReq := claudeReq_.(*claude.ClaudeRequest) + claudeReq := claudeReq_.(*dto.ClaudeRequest) awsClaudeReq := copyRequest(claudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { @@ -89,14 +89,14 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* return wrapErr(errors.Wrap(err, "InvokeModel")), nil } - claudeResponse := new(claude.ClaudeResponse) + claudeResponse := new(dto.ClaudeResponse) err = json.Unmarshal(awsResp.Body, claudeResponse) if err != nil { return wrapErr(errors.Wrap(err, "unmarshal response")), nil } openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse) - usage := relaymodel.Usage{ + usage := dto.Usage{ PromptTokens: claudeResponse.Usage.InputTokens, CompletionTokens: claudeResponse.Usage.OutputTokens, TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, @@ -107,7 +107,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* return nil, &usage } -func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { +func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { awsCli, err := newAwsClient(c, info) if err != nil { return wrapErr(errors.Wrap(err, "newAwsClient")), nil @@ -128,7 +128,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel if !ok { return wrapErr(errors.New("request not found")), nil } - claudeReq := claudeReq_.(*claude.ClaudeRequest) + claudeReq := claudeReq_.(*dto.ClaudeRequest) awsClaudeReq := copyRequest(claudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq) @@ -149,7 +149,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel Created: common.GetTimestamp(), Model: info.UpstreamModelName, ResponseText: strings.Builder{}, - Usage: &relaymodel.Usage{}, + Usage: &dto.Usage{}, } isFirst := true c.Stream(func(w io.Writer) bool { @@ -164,7 +164,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel isFirst = false info.FirstResponseTime = time.Now() } - claudeResponse := new(claude.ClaudeResponse) + claudeResponse := new(dto.ClaudeResponse) err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 46a1f964..105f2a9b 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index fd25ecc1..855ed717 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index bf03e5f5..a5c475fa 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -22,6 +22,10 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return request, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go index 9532ca74..89415868 100644 --- a/relay/channel/claude/dto.go +++ b/relay/channel/claude/dto.go @@ -1,94 +1,95 @@ package claude -type ClaudeMetadata struct { - UserId string `json:"user_id"` -} - -type ClaudeMediaMessage struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Source *ClaudeMessageSource `json:"source,omitempty"` - Usage *ClaudeUsage `json:"usage,omitempty"` - StopReason *string `json:"stop_reason,omitempty"` - PartialJson string `json:"partial_json,omitempty"` - Thinking string `json:"thinking,omitempty"` - Signature string `json:"signature,omitempty"` - Delta string `json:"delta,omitempty"` - // tool_calls - Id string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input any `json:"input,omitempty"` - Content string `json:"content,omitempty"` - ToolUseId string `json:"tool_use_id,omitempty"` -} - -type ClaudeMessageSource struct { - Type string `json:"type"` - MediaType string `json:"media_type"` - Data string `json:"data"` -} - -type ClaudeMessage struct { - Role string `json:"role"` - Content any `json:"content"` -} - -type Tool struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - InputSchema map[string]interface{} `json:"input_schema"` -} - -type InputSchema struct { - Type string `json:"type"` - Properties any `json:"properties,omitempty"` - Required any `json:"required,omitempty"` -} - -type ClaudeRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt,omitempty"` - System string `json:"system,omitempty"` - Messages []ClaudeMessage `json:"messages,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - //ClaudeMetadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - Thinking *Thinking `json:"thinking,omitempty"` -} - -type Thinking struct { - Type string `json:"type"` - BudgetTokens int `json:"budget_tokens"` -} - -type ClaudeError struct { - Type string `json:"type"` - Message string `json:"message"` -} - -type ClaudeResponse struct { - Id string `json:"id"` - Type string `json:"type"` - Content []ClaudeMediaMessage `json:"content"` - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error ClaudeError `json:"error"` - Usage ClaudeUsage `json:"usage"` - Index int `json:"index"` // stream only - ContentBlock *ClaudeMediaMessage `json:"content_block"` - Delta *ClaudeMediaMessage `json:"delta"` // stream only - Message *ClaudeResponse `json:"message"` // stream only: message_start -} - -type ClaudeUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} +// +//type ClaudeMetadata struct { +// UserId string `json:"user_id"` +//} +// +//type ClaudeMediaMessage struct { +// Type string `json:"type"` +// Text string `json:"text,omitempty"` +// Source *ClaudeMessageSource `json:"source,omitempty"` +// Usage *ClaudeUsage `json:"usage,omitempty"` +// StopReason *string `json:"stop_reason,omitempty"` +// PartialJson string `json:"partial_json,omitempty"` +// Thinking string `json:"thinking,omitempty"` +// Signature string `json:"signature,omitempty"` +// Delta string `json:"delta,omitempty"` +// // tool_calls +// Id string `json:"id,omitempty"` +// Name string `json:"name,omitempty"` +// Input any `json:"input,omitempty"` +// Content string `json:"content,omitempty"` +// ToolUseId string `json:"tool_use_id,omitempty"` +//} +// +//type ClaudeMessageSource struct { +// Type string `json:"type"` +// MediaType string `json:"media_type"` +// Data string `json:"data"` +//} +// +//type ClaudeMessage struct { +// Role string `json:"role"` +// Content any `json:"content"` +//} +// +//type Tool struct { +// Name string `json:"name"` +// Description string `json:"description,omitempty"` +// InputSchema map[string]interface{} `json:"input_schema"` +//} +// +//type InputSchema struct { +// Type string `json:"type"` +// Properties any `json:"properties,omitempty"` +// Required any `json:"required,omitempty"` +//} +// +//type ClaudeRequest struct { +// Model string `json:"model"` +// Prompt string `json:"prompt,omitempty"` +// System string `json:"system,omitempty"` +// Messages []ClaudeMessage `json:"messages,omitempty"` +// MaxTokens uint `json:"max_tokens,omitempty"` +// MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` +// StopSequences []string `json:"stop_sequences,omitempty"` +// Temperature *float64 `json:"temperature,omitempty"` +// TopP float64 `json:"top_p,omitempty"` +// TopK int `json:"top_k,omitempty"` +// //ClaudeMetadata `json:"metadata,omitempty"` +// Stream bool `json:"stream,omitempty"` +// Tools any `json:"tools,omitempty"` +// ToolChoice any `json:"tool_choice,omitempty"` +// Thinking *Thinking `json:"thinking,omitempty"` +//} +// +//type Thinking struct { +// Type string `json:"type"` +// BudgetTokens int `json:"budget_tokens"` +//} +// +//type ClaudeError struct { +// Type string `json:"type"` +// Message string `json:"message"` +//} +// +//type ClaudeResponse struct { +// Id string `json:"id"` +// Type string `json:"type"` +// Content []ClaudeMediaMessage `json:"content"` +// Completion string `json:"completion"` +// StopReason string `json:"stop_reason"` +// Model string `json:"model"` +// Error ClaudeError `json:"error"` +// Usage ClaudeUsage `json:"usage"` +// Index int `json:"index"` // stream only +// ContentBlock *ClaudeMediaMessage `json:"content_block"` +// Delta *ClaudeMediaMessage `json:"delta"` // stream only +// Message *ClaudeResponse `json:"message"` // stream only: message_start +//} +// +//type ClaudeUsage struct { +// InputTokens int `json:"input_tokens"` +// OutputTokens int `json:"output_tokens"` +//} diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 011694df..74b73454 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -30,9 +30,9 @@ func stopReasonClaude2OpenAI(reason string) string { } } -func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { +func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest { - claudeRequest := ClaudeRequest{ + claudeRequest := dto.ClaudeRequest{ Model: textRequest.Model, Prompt: "", StopSequences: nil, @@ -61,12 +61,12 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR return &claudeRequest } -func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { - claudeTools := make([]Tool, 0, len(textRequest.Tools)) +func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { + claudeTools := make([]dto.Tool, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { if params, ok := tool.Function.Parameters.(map[string]any); ok { - claudeTool := Tool{ + claudeTool := dto.Tool{ Name: tool.Function.Name, Description: tool.Function.Description, } @@ -84,7 +84,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR } } - claudeRequest := ClaudeRequest{ + claudeRequest := dto.ClaudeRequest{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, StopSequences: nil, @@ -108,7 +108,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR } // BudgetTokens 为 max_tokens 的 80% - claudeRequest.Thinking = &Thinking{ + claudeRequest.Thinking = &dto.Thinking{ Type: "enabled", BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage), } @@ -166,7 +166,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR lastMessage = fmtMessage } - claudeMessages := make([]ClaudeMessage, 0) + claudeMessages := make([]dto.ClaudeMessage, 0) isFirstMessage := true for _, message := range formatMessages { if message.Role == "system" { @@ -187,63 +187,63 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR isFirstMessage = false if message.Role != "user" { // fix: first message is assistant, add user message - claudeMessage := ClaudeMessage{ + claudeMessage := dto.ClaudeMessage{ Role: "user", - Content: []ClaudeMediaMessage{ + Content: []dto.ClaudeMediaMessage{ { Type: "text", - Text: "...", + Text: common.GetPointer[string]("..."), }, }, } claudeMessages = append(claudeMessages, claudeMessage) } } - claudeMessage := ClaudeMessage{ + claudeMessage := dto.ClaudeMessage{ Role: message.Role, } if message.Role == "tool" { if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" { lastMessage := claudeMessages[len(claudeMessages)-1] if content, ok := lastMessage.Content.(string); ok { - lastMessage.Content = []ClaudeMediaMessage{ + lastMessage.Content = []dto.ClaudeMediaMessage{ { Type: "text", - Text: content, + Text: common.GetPointer[string](content), }, } } - lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{ + lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{ Type: "tool_result", ToolUseId: message.ToolCallId, - Content: message.StringContent(), + Content: message.Content, }) claudeMessages[len(claudeMessages)-1] = lastMessage continue } else { claudeMessage.Role = "user" - claudeMessage.Content = []ClaudeMediaMessage{ + claudeMessage.Content = []dto.ClaudeMediaMessage{ { Type: "tool_result", ToolUseId: message.ToolCallId, - Content: message.StringContent(), + Content: message.Content, }, } } } else if message.IsStringContent() && message.ToolCalls == nil { claudeMessage.Content = message.StringContent() } else { - claudeMediaMessages := make([]ClaudeMediaMessage, 0) + claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0) for _, mediaMessage := range message.ParseContent() { - claudeMediaMessage := ClaudeMediaMessage{ + claudeMediaMessage := dto.ClaudeMediaMessage{ Type: mediaMessage.Type, } if mediaMessage.Type == "text" { - claudeMediaMessage.Text = mediaMessage.Text + claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text) } else { imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) claudeMediaMessage.Type = "image" - claudeMediaMessage.Source = &ClaudeMessageSource{ + claudeMediaMessage.Source = &dto.ClaudeMessageSource{ Type: "base64", } // 判断是否是url @@ -273,7 +273,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) continue } - claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{ + claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ Type: "tool_use", Id: toolCall.ID, Name: toolCall.Function.Name, @@ -291,7 +291,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR return &claudeRequest, nil } -func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse { +func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse { var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model @@ -329,8 +329,8 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *d } } else if claudeResponse.Type == "content_block_delta" { if claudeResponse.Delta != nil { - choice.Index = claudeResponse.Index - choice.Delta.SetContentString(claudeResponse.Delta.Text) + choice.Index = *claudeResponse.Index + choice.Delta.Content = claudeResponse.Delta.Text switch claudeResponse.Delta.Type { case "input_json_delta": tools = append(tools, dto.ToolCallResponse{ @@ -368,7 +368,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *d return &response } -func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { +func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse { choices := make([]dto.OpenAITextResponseChoice, 0) fullTextResponse := dto.OpenAITextResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), @@ -377,7 +377,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope } var responseText string if len(claudeResponse.Content) > 0 { - responseText = claudeResponse.Content[0].Text + responseText = *claudeResponse.Content[0].Text } tools := make([]dto.ToolCallResponse, 0) thinkingContent := "" @@ -412,7 +412,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope // 加密的不管, 只输出明文的推理过程 thinkingContent = message.Thinking case "text": - responseText = message.Text + responseText = *message.Text } } } @@ -442,7 +442,7 @@ type ClaudeResponseInfo struct { Usage *dto.Usage } -func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool { +func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool { if requestMode == RequestModeCompletion { claudeInfo.ResponseText.WriteString(claudeResponse.Completion) } else { @@ -452,7 +452,9 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, o claudeInfo.Model = claudeResponse.Message.Model claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens } else if claudeResponse.Type == "content_block_delta" { - claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Text) + if claudeResponse.Delta.Text != nil { + claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text) + } } else if claudeResponse.Type == "message_delta" { claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens @@ -470,6 +472,61 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, o } func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + + if info.RelayFormat == relaycommon.RelayFormatOpenAI { + return toOpenAIStreamHandler(c, resp, info, requestMode) + } + + usage := &dto.Usage{} + responseText := strings.Builder{} + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + var claudeResponse dto.ClaudeResponse + err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if requestMode == RequestModeCompletion { + responseText.WriteString(claudeResponse.Completion) + } else { + if claudeResponse.Type == "message_start" { + // message_start, 获取usage + info.UpstreamModelName = claudeResponse.Message.Model + usage.PromptTokens = claudeResponse.Message.Usage.InputTokens + usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens + usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens + usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens + } else if claudeResponse.Type == "content_block_delta" { + responseText.WriteString(claudeResponse.Delta.GetText()) + } else if claudeResponse.Type == "message_delta" { + if claudeResponse.Usage.InputTokens > 0 { + // 不叠加,只取最新的 + usage.PromptTokens = claudeResponse.Usage.InputTokens + } + usage.CompletionTokens = claudeResponse.Usage.OutputTokens + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + } + helper.ClaudeChunkData(c, claudeResponse, data) + return true + }) + + if requestMode == RequestModeCompletion { + usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + // 说明流模式建立失败,可能为官方出错 + if usage.PromptTokens == 0 { + //usage.PromptTokens = info.PromptTokens + } + if usage.CompletionTokens == 0 { + usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, usage.PromptTokens) + } + } + return nil, usage +} + +func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) claudeInfo := &ClaudeResponseInfo{ ResponseId: responseId, @@ -480,7 +537,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } helper.StreamScannerHandler(c, resp, info, func(data string) bool { - var claudeResponse ClaudeResponse + var claudeResponse dto.ClaudeResponse err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -530,7 +587,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var claudeResponse ClaudeResponse + var claudeResponse dto.ClaudeResponse err = json.Unmarshal(responseBody, &claudeResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil @@ -546,13 +603,12 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r StatusCode: resp.StatusCode, }, nil } - fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) - if err != nil { - return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil - } usage := dto.Usage{} if requestMode == RequestModeCompletion { + completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) + if err != nil { + return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil + } usage.PromptTokens = info.PromptTokens usage.CompletionTokens = completionTokens usage.TotalTokens = info.PromptTokens + completionTokens @@ -560,14 +616,23 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r usage.PromptTokens = claudeResponse.Usage.InputTokens usage.CompletionTokens = claudeResponse.Usage.OutputTokens usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens + usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens + usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens } - fullTextResponse.Usage = usage - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + var responseData []byte + switch info.RelayFormat { + case relaycommon.RelayFormatOpenAI: + openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) + openaiResponse.Usage = usage + responseData, err = json.Marshal(openaiResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + case relaycommon.RelayFormatClaude: + responseData = responseBody } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) + _, err = c.Writer.Write(responseData) return nil, &usage } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 5c2eadc2..b21e25f3 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -17,6 +17,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index d552a53b..7675d546 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -59,7 +65,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = cohereRerankHandler(c, resp, info) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index d779ee65..ad01b8f4 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 2626dd7d..96aff447 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -23,6 +23,12 @@ type Adaptor struct { BotType int } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 1b7131dc..a629968b 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -21,6 +21,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 77076bd4..bcfc8dea 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index fcea169a..80547346 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -14,6 +14,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index 9670ec94..151072cb 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -73,13 +79,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - + switch info.RelayMode { case constant.RelayModeEmbeddings: err, usage = mokaEmbeddingHandler(c, resp) default: // err, usage = mokaHandler(c, resp) - + } return } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 7e1c6237..4190dd3f 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index d8a44335..196343e8 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -29,6 +29,12 @@ type Adaptor struct { ResponseFormat string } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType } diff --git a/relay/channel/openrouter/adaptor.go b/relay/channel/openrouter/adaptor.go index 83afb6af..aef5afeb 100644 --- a/relay/channel/openrouter/adaptor.go +++ b/relay/channel/openrouter/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index f38fa95b..69ef5001 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -54,7 +60,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 2b27bdb1..de84406c 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -57,7 +63,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 797f0244..754a1f00 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 768ef646..28a02aae 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -23,6 +23,12 @@ type Adaptor struct { Timestamp int64 } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -78,7 +84,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 7ccd3f30..2f348e46 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -38,6 +38,9 @@ type Adaptor struct { AccountCredentials Credentials } +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return request, nil +} func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/vertex/dto.go b/relay/channel/vertex/dto.go index 4ba570de..4a571612 100644 --- a/relay/channel/vertex/dto.go +++ b/relay/channel/vertex/dto.go @@ -1,25 +1,25 @@ package vertex import ( - "one-api/relay/channel/claude" + "one-api/dto" ) type VertexAIClaudeRequest struct { - AnthropicVersion string `json:"anthropic_version"` - Messages []claude.ClaudeMessage `json:"messages"` - System any `json:"system,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - Thinking *claude.Thinking `json:"thinking,omitempty"` + AnthropicVersion string `json:"anthropic_version"` + Messages []dto.ClaudeMessage `json:"messages"` + System any `json:"system,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Thinking *dto.Thinking `json:"thinking,omitempty"` } -func copyRequest(req *claude.ClaudeRequest, version string) *VertexAIClaudeRequest { +func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest { return &VertexAIClaudeRequest{ AnthropicVersion: version, System: req.System, diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 3b57c67c..f423d587 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -17,6 +17,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 71fd1367..d66f3732 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -16,6 +16,12 @@ type Adaptor struct { request *dto.GeneralOpenAIRequest } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -55,7 +61,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 87ff20d5..aa612f0c 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -14,6 +14,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -61,7 +67,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 5983c1d9..7a23e212 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -58,7 +64,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/claude_handler.go b/relay/claude_handler.go new file mode 100644 index 00000000..97de772b --- /dev/null +++ b/relay/claude_handler.go @@ -0,0 +1,162 @@ +package relay + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "one-api/setting/model_setting" + "strings" +) + +func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { + textRequest = &dto.ClaudeRequest{} + err = c.ShouldBindJSON(textRequest) + if err != nil { + return nil, err + } + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + if textRequest.Model == "" { + return nil, errors.New("field model is required") + } + return textRequest, nil +} + +func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { + + relayInfo := relaycommon.GenRelayInfoClaude(c) + + // get & validate textRequest 获取并验证文本请求 + textRequest, err := getAndValidateClaudeRequest(c) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest) + } + + if textRequest.Stream { + relayInfo.IsStream = true + } + + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) + } + + textRequest.Model = relayInfo.UpstreamModelName + + promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) + // count messages token error 计算promptTokens错误 + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError) + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } + + // pre-consume quota 预消耗配额 + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + + if openaiErr != nil { + return service.OpenAIErrorToClaudeError(openaiErr) + } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + var requestBody io.Reader + + if textRequest.MaxTokens == 0 { + textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model)) + } + + if model_setting.GetClaudeSettings().ThinkingAdapterEnabled && + strings.HasSuffix(textRequest.Model, "-thinking") { + if textRequest.Thinking == nil { + // 因为BudgetTokens 必须大于1024 + if textRequest.MaxTokens < 1280 { + textRequest.MaxTokens = 1280 + } + + // BudgetTokens 为 max_tokens 的 80% + textRequest.Thinking = &dto.Thinking{ + Type: "enabled", + BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage), + } + // TODO: 临时处理 + // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking + textRequest.TopP = 0 + textRequest.Temperature = common.GetPointer[float64](1.0) + } + textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") + relayInfo.UpstreamModelName = textRequest.Model + } + + convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonData) + + //log.Printf("requestBody: %s", requestBody) + + statusCodeMappingStr := c.GetString("status_code_mapping") + var httpResp *http.Response + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + } + + if resp != nil { + httpResp = resp.(*http.Response) + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + if httpResp.StatusCode != http.StatusOK { + openaiErr = service.RelayErrorHandler(httpResp, false) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return service.OpenAIErrorToClaudeError(openaiErr) + } + } + + usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) + //log.Printf("usage: %v", usage) + if openaiErr != nil { + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return service.OpenAIErrorToClaudeError(openaiErr) + } + service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + return nil +} + +func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) { + var promptTokens int + var err error + switch info.RelayMode { + default: + promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName) + } + info.PromptTokens = promptTokens + return promptTokens, err +} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index c1d3f4a4..3b5ef795 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -17,6 +17,11 @@ type ThinkingContentInfo struct { SendLastThinkingContent bool } +const ( + RelayFormatOpenAI = "openai" + RelayFormatClaude = "claude" +) + type RelayInfo struct { ChannelType int ChannelId int @@ -58,6 +63,8 @@ type RelayInfo struct { UserSetting map[string]interface{} UserEmail string UserQuota int + RelayFormat string + ResponseTimes int64 ThinkingContentInfo } @@ -82,6 +89,13 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { return info } +func GenRelayInfoClaude(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatClaude + info.ShouldIncludeUsage = false + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") @@ -123,6 +137,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Organization: c.GetString("channel_organization"), ChannelSetting: channelSetting, + RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, @@ -157,6 +172,7 @@ func (info *RelayInfo) SetIsStream(isStream bool) { } func (info *RelayInfo) SetFirstResponseTime() { + info.ResponseTimes++ if info.isFirstResponse { info.FirstResponseTime = time.Now() info.isFirstResponse = false diff --git a/relay/helper/common.go b/relay/helper/common.go index 2a72d30a..6af55a86 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -19,6 +19,14 @@ func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)}) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } +} + func StringData(c *gin.Context, str string) error { //str = strings.TrimPrefix(str, "data: ") //str = strings.TrimSuffix(str, "\r") diff --git a/relay/helper/price.go b/relay/helper/price.go index b169df98..1ae3d2fc 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -16,6 +16,7 @@ type PriceData struct { CacheRatio float64 GroupRatio float64 UsePrice bool + CacheCreationRatio float64 ShouldPreConsumedQuota int } @@ -26,6 +27,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var modelRatio float64 var completionRatio float64 var cacheRatio float64 + var cacheCreationRatio float64 if !usePrice { preConsumedTokens := common.PreConsumedQuota if maxTokens != 0 { @@ -42,6 +44,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens } completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName) cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName) + cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName) ratio := modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { @@ -54,6 +57,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens GroupRatio: groupRatio, UsePrice: usePrice, CacheRatio: cacheRatio, + CacheCreationRatio: cacheCreationRatio, ShouldPreConsumedQuota: preConsumedQuota, }, nil } diff --git a/router/relay-router.go b/router/relay-router.go index 32e0c682..3a9122d4 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -35,6 +35,7 @@ func SetRelayRouter(router *gin.Engine) { //http router httpRouter := relayV1Router.Group("") httpRouter.Use(middleware.Distribute()) + httpRouter.POST("/messages", controller.RelayClaude) httpRouter.POST("/completions", controller.Relay) httpRouter.POST("/chat/completions", controller.Relay) httpRouter.POST("/edits", controller.Relay) diff --git a/service/convert.go b/service/convert.go new file mode 100644 index 00000000..c4916df2 --- /dev/null +++ b/service/convert.go @@ -0,0 +1,310 @@ +package service + +import ( + "encoding/json" + "fmt" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" +) + +func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIRequest, error) { + openAIRequest := dto.GeneralOpenAIRequest{ + Model: claudeRequest.Model, + MaxTokens: claudeRequest.MaxTokens, + Temperature: claudeRequest.Temperature, + TopP: claudeRequest.TopP, + Stream: claudeRequest.Stream, + } + + // Convert stop sequences + if len(claudeRequest.StopSequences) == 1 { + openAIRequest.Stop = claudeRequest.StopSequences[0] + } else if len(claudeRequest.StopSequences) > 1 { + openAIRequest.Stop = claudeRequest.StopSequences + } + + // Convert tools + tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools) + openAITools := make([]dto.ToolCallRequest, 0) + for _, claudeTool := range tools { + openAITool := dto.ToolCallRequest{ + Type: "function", + Function: dto.FunctionRequest{ + Name: claudeTool.Name, + Description: claudeTool.Description, + Parameters: claudeTool.InputSchema, + }, + } + openAITools = append(openAITools, openAITool) + } + openAIRequest.Tools = openAITools + + // Convert messages + openAIMessages := make([]dto.Message, 0) + + // Add system message if present + if claudeRequest.IsStringSystem() { + openAIMessage := dto.Message{ + Role: "system", + } + openAIMessage.SetStringContent(claudeRequest.GetStringSystem()) + openAIMessages = append(openAIMessages, openAIMessage) + } else { + systems := claudeRequest.ParseSystem() + if len(systems) > 0 { + systemStr := "" + openAIMessage := dto.Message{ + Role: "system", + } + for _, system := range systems { + systemStr += system.Type + } + openAIMessage.SetStringContent(systemStr) + openAIMessages = append(openAIMessages, openAIMessage) + } + } + for _, claudeMessage := range claudeRequest.Messages { + openAIMessage := dto.Message{ + Role: claudeMessage.Role, + } + + //log.Printf("claudeMessage.Content: %v", claudeMessage.Content) + if claudeMessage.IsStringContent() { + openAIMessage.SetStringContent(claudeMessage.GetStringContent()) + } else { + content, err := claudeMessage.ParseContent() + if err != nil { + return nil, err + } + contents := content + var toolCalls []dto.ToolCallRequest + mediaMessages := make([]dto.MediaContent, 0, len(contents)) + + for _, mediaMsg := range contents { + switch mediaMsg.Type { + case "text": + message := dto.MediaContent{ + Type: "text", + Text: mediaMsg.GetText(), + } + mediaMessages = append(mediaMessages, message) + case "image": + // Handle image conversion (base64 to URL or keep as is) + imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data) + //textContent += fmt.Sprintf("[Image: %s]", imageData) + mediaMessage := dto.MediaContent{ + Type: "image_url", + ImageUrl: &dto.MessageImageUrl{Url: imageData}, + } + mediaMessages = append(mediaMessages, mediaMessage) + case "tool_use": + toolCall := dto.ToolCallRequest{ + ID: mediaMsg.Id, + Function: dto.FunctionRequest{ + Name: mediaMsg.Name, + Arguments: toJSONString(mediaMsg.Input), + }, + } + toolCalls = append(toolCalls, toolCall) + case "tool_result": + // Add tool result as a separate message + oaiToolMessage := dto.Message{ + Role: "tool", + ToolCallId: mediaMsg.ToolUseId, + } + oaiToolMessage.Content = mediaMsg.Content + } + } + + openAIMessage.SetMediaContent(mediaMessages) + + if len(toolCalls) > 0 { + openAIMessage.SetToolCalls(toolCalls) + } + } + + openAIMessages = append(openAIMessages, openAIMessage) + } + + openAIRequest.Messages = openAIMessages + + return &openAIRequest, nil +} + +func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode { + claudeError := dto.ClaudeError{ + Type: "new_api_error", + Message: openAIError.Error.Message, + } + return &dto.ClaudeErrorWithStatusCode{ + Error: claudeError, + StatusCode: openAIError.StatusCode, + } +} + +func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode { + openAIError := dto.OpenAIError{ + Message: claudeError.Error.Message, + Type: "new_api_error", + } + return &dto.OpenAIErrorWithStatusCode{ + Error: openAIError, + StatusCode: claudeError.StatusCode, + } +} + +func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse { + var claudeResponses []*dto.ClaudeResponse + if info.ResponseTimes == 1 { + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_start", + Message: &dto.ClaudeMediaMessage{ + Id: openAIResponse.Id, + Model: openAIResponse.Model, + Type: "message", + Role: "assistant", + Usage: &dto.ClaudeUsage{ + InputTokens: info.PromptTokens, + OutputTokens: 0, + }, + }, + }) + if openAIResponse.IsToolCall() { + resp := &dto.ClaudeResponse{ + Type: "content_block_start", + ContentBlock: &dto.ClaudeMediaMessage{ + Id: openAIResponse.GetFirstToolCall().ID, + Type: "tool_use", + Name: openAIResponse.GetFirstToolCall().Function.Name, + }, + } + resp.SetIndex(0) + claudeResponses = append(claudeResponses, resp) + } else { + resp := &dto.ClaudeResponse{ + Type: "content_block_start", + ContentBlock: &dto.ClaudeMediaMessage{ + Type: "text", + Text: common.GetPointer[string](""), + }, + } + resp.SetIndex(0) + claudeResponses = append(claudeResponses, resp) + } + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "ping", + }) + return claudeResponses + } + + if len(openAIResponse.Choices) == 0 { + // no choices + // TODO: handle this case + } else { + chosenChoice := openAIResponse.Choices[0] + if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" { + // should be done + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "content_block_stop", + Index: common.GetPointer[int](0), + }) + if openAIResponse.Usage != nil { + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_delta", + Usage: &dto.ClaudeUsage{ + InputTokens: openAIResponse.Usage.PromptTokens, + OutputTokens: openAIResponse.Usage.CompletionTokens, + }, + Delta: &dto.ClaudeMediaMessage{ + StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(*chosenChoice.FinishReason)), + }, + }) + } + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_stop", + }) + } else { + var claudeResponse dto.ClaudeResponse + claudeResponse.SetIndex(0) + claudeResponse.Type = "content_block_delta" + if len(chosenChoice.Delta.ToolCalls) > 0 { + // tools delta + claudeResponse.Delta = &dto.ClaudeMediaMessage{ + Type: "input_json_delta", + PartialJson: chosenChoice.Delta.ToolCalls[0].Function.Arguments, + } + } else { + // text delta + claudeResponse.Delta = &dto.ClaudeMediaMessage{ + Type: "text_delta", + Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()), + } + } + claudeResponses = append(claudeResponses, &claudeResponse) + } + } + + return claudeResponses +} + +func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse { + var stopReason string + contents := make([]dto.ClaudeMediaMessage, 0) + claudeResponse := &dto.ClaudeResponse{ + Id: openAIResponse.Id, + Type: "message", + Role: "assistant", + Model: openAIResponse.Model, + } + for _, choice := range openAIResponse.Choices { + stopReason = stopReasonOpenAI2Claude(choice.FinishReason) + claudeContent := dto.ClaudeMediaMessage{} + if choice.FinishReason == "tool_calls" { + claudeContent.Type = "tool_use" + claudeContent.Id = choice.Message.ToolCallId + claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name + var mapParams map[string]interface{} + if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil { + claudeContent.Input = mapParams + } else { + claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments + } + } else { + claudeContent.Type = "text" + claudeContent.SetText(choice.Message.StringContent()) + } + contents = append(contents, claudeContent) + } + claudeResponse.Content = contents + claudeResponse.StopReason = stopReason + claudeResponse.Usage = &dto.ClaudeUsage{ + InputTokens: openAIResponse.PromptTokens, + OutputTokens: openAIResponse.CompletionTokens, + } + + return claudeResponse +} + +func stopReasonOpenAI2Claude(reason string) string { + switch reason { + case "stop": + return "end_turn" + case "stop_sequence": + return "stop_sequence" + case "max_tokens": + return "max_tokens" + case "tool_calls": + return "tool_use" + default: + return reason + } +} + +func toJSONString(v interface{}) string { + b, err := json.Marshal(v) + if err != nil { + return "{}" + } + return string(b) +} diff --git a/service/error.go b/service/error.go index 82fbda18..9824a853 100644 --- a/service/error.go +++ b/service/error.go @@ -50,6 +50,30 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI return openaiErr } +func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode { + text := err.Error() + lowerText := strings.ToLower(text) + if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { + common.SysLog(fmt.Sprintf("error: %s", text)) + text = "请求上游地址失败" + } + claudeError := dto.ClaudeError{ + Message: text, + Type: "new_api_error", + //Code: code, + } + return &dto.ClaudeErrorWithStatusCode{ + Error: claudeError, + StatusCode: statusCode, + } +} + +func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode { + claudeErr := ClaudeErrorWrapper(err, code, statusCode) + claudeErr.LocalError = true + return claudeErr +} + func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) { errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 6406cbe1..75457b97 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -53,3 +53,12 @@ func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, info["audio_completion_ratio"] = audioCompletionRatio return info } + +func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64, + cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64) map[string]interface{} { + info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice) + info["claude"] = true + info["cache_creation_tokens"] = cacheCreationTokens + info["cache_creation_ratio"] = cacheCreationRatio + return info +} diff --git a/service/quota.go b/service/quota.go index e19f1b82..ec5af57a 100644 --- a/service/quota.go +++ b/service/quota.go @@ -194,6 +194,75 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } +func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, + usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { + + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + promptTokens := usage.PromptTokens + completionTokens := usage.CompletionTokens + modelName := relayInfo.OriginModelName + + tokenName := ctx.GetString("token_name") + completionRatio := priceData.CompletionRatio + modelRatio := priceData.ModelRatio + groupRatio := priceData.GroupRatio + modelPrice := priceData.ModelPrice + + cacheRatio := priceData.CacheRatio + cacheTokens := usage.PromptTokensDetails.CachedTokens + + cacheCreationRatio := priceData.CacheCreationRatio + cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens + + calculateQuota := 0.0 + if !priceData.UsePrice { + calculateQuota = float64(promptTokens) + calculateQuota += float64(cacheTokens) * cacheRatio + calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio + calculateQuota += float64(completionTokens) * completionRatio + calculateQuota = calculateQuota * groupRatio * modelRatio + } else { + calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio + } + + if modelRatio != 0 && calculateQuota <= 0 { + calculateQuota = 1 + } + + quota := int(calculateQuota) + + totalTokens := promptTokens + completionTokens + + var logContent string + // record all the consume log even if quota is 0 + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + logContent += fmt.Sprintf("(可能是上游出错)") + common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + } else { + //if sensitiveResp != nil { + // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) + //} + quotaDelta := quota - preConsumedQuota + if quotaDelta != 0 { + err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + } + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, + cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice) + model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName, + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) +} + func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { diff --git a/service/token_counter.go b/service/token_counter.go index a6b8e86a..98386f85 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -1,6 +1,7 @@ package service import ( + "encoding/json" "errors" "fmt" "image" @@ -192,6 +193,110 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA return tkm, nil } +func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) { + tkm := 0 + + // Count tokens in messages + msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream) + if err != nil { + return 0, err + } + tkm += msgTokens + + // Count tokens in system message + if request.System != "" { + systemTokens, err := CountTokenInput(request.System, model) + if err != nil { + return 0, err + } + tkm += systemTokens + } + + if request.Tools != nil { + // check is array + if tools, ok := request.Tools.([]any); ok { + if len(tools) > 0 { + parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools) + if err1 != nil { + return 0, fmt.Errorf("tools: Input should be a valid list: %v", err) + } + toolTokens, err2 := CountTokenClaudeTools(parsedTools, model) + if err2 != nil { + return 0, fmt.Errorf("tools: %v", err) + } + tkm += toolTokens + } + } else { + return 0, errors.New("tools: Input should be a valid list") + } + } + + return tkm, nil +} + +func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) { + tokenEncoder := getTokenEncoder(model) + tokenNum := 0 + + for _, message := range messages { + // Count tokens for role + tokenNum += getTokenNum(tokenEncoder, message.Role) + if message.IsStringContent() { + tokenNum += getTokenNum(tokenEncoder, message.GetStringContent()) + } else { + content, err := message.ParseContent() + if err != nil { + return 0, err + } + for _, mediaMessage := range content { + switch mediaMessage.Type { + case "text": + tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText()) + case "image": + //imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream) + //if err != nil { + // return 0, err + //} + tokenNum += 1000 + case "tool_use": + tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name) + inputJSON, _ := json.Marshal(mediaMessage.Input) + tokenNum += getTokenNum(tokenEncoder, string(inputJSON)) + case "tool_result": + contentJSON, _ := json.Marshal(mediaMessage.Content) + tokenNum += getTokenNum(tokenEncoder, string(contentJSON)) + } + } + } + } + + // Add a constant for message formatting (this may need adjustment based on Claude's exact formatting) + tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting + + return tokenNum, nil +} + +func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) { + tokenEncoder := getTokenEncoder(model) + tokenNum := 0 + + for _, tool := range tools { + tokenNum += getTokenNum(tokenEncoder, tool.Name) + tokenNum += getTokenNum(tokenEncoder, tool.Description) + + schemaJSON, err := json.Marshal(tool.InputSchema) + if err != nil { + return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error())) + } + tokenNum += getTokenNum(tokenEncoder, string(schemaJSON)) + } + + // Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting) + tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting + + return tokenNum, nil +} + func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) { audioToken := 0 textToken := 0 diff --git a/setting/operation_setting/cache_ratio.go b/setting/operation_setting/cache_ratio.go index 98f022ed..01d79c10 100644 --- a/setting/operation_setting/cache_ratio.go +++ b/setting/operation_setting/cache_ratio.go @@ -7,26 +7,45 @@ import ( ) var defaultCacheRatio = map[string]float64{ - "gpt-4": 0.5, - "o1": 0.5, - "o1-2024-12-17": 0.5, - "o1-preview-2024-09-12": 0.5, - "o1-preview": 0.5, - "o1-mini-2024-09-12": 0.5, - "o1-mini": 0.5, - "gpt-4o-2024-11-20": 0.5, - "gpt-4o-2024-08-06": 0.5, - "gpt-4o": 0.5, - "gpt-4o-mini-2024-07-18": 0.5, - "gpt-4o-mini": 0.5, - "gpt-4o-realtime-preview": 0.5, - "gpt-4o-mini-realtime-preview": 0.5, - "deepseek-chat": 0.1, - "deepseek-reasoner": 0.1, - "deepseek-coder": 0.1, + "gpt-4": 0.5, + "o1": 0.5, + "o1-2024-12-17": 0.5, + "o1-preview-2024-09-12": 0.5, + "o1-preview": 0.5, + "o1-mini-2024-09-12": 0.5, + "o1-mini": 0.5, + "gpt-4o-2024-11-20": 0.5, + "gpt-4o-2024-08-06": 0.5, + "gpt-4o": 0.5, + "gpt-4o-mini-2024-07-18": 0.5, + "gpt-4o-mini": 0.5, + "gpt-4o-realtime-preview": 0.5, + "gpt-4o-mini-realtime-preview": 0.5, + "deepseek-chat": 0.25, + "deepseek-reasoner": 0.25, + "deepseek-coder": 0.25, + "claude-3-sonnet-20240229": 0.1, + "claude-3-opus-20240229": 0.1, + "claude-3-haiku-20240307": 0.1, + "claude-3-5-haiku-20241022": 0.1, + "claude-3-5-sonnet-20240620": 0.1, + "claude-3-5-sonnet-20241022": 0.1, + "claude-3-7-sonnet-20250219": 0.1, + "claude-3-7-sonnet-20250219-thinking": 0.1, } -var defaultCreateCacheRatio = map[string]float64{} +var defaultCreateCacheRatio = map[string]float64{ + "claude-3-sonnet-20240229": 1.25, + "claude-3-opus-20240229": 1.25, + "claude-3-haiku-20240307": 1.25, + "claude-3-5-haiku-20241022": 1.25, + "claude-3-5-sonnet-20240620": 1.25, + "claude-3-5-sonnet-20241022": 1.25, + "claude-3-7-sonnet-20250219": 1.25, + "claude-3-7-sonnet-20250219-thinking": 1.25, +} + +//var defaultCreateCacheRatio = map[string]float64{} var cacheRatioMap map[string]float64 var cacheRatioMapMutex sync.RWMutex @@ -69,16 +88,10 @@ func GetCacheRatio(name string) (float64, bool) { return ratio, true } -// DefaultCacheRatio2JSONString converts the default cache ratio map to a JSON string -func DefaultCacheRatio2JSONString() string { - jsonBytes, err := json.Marshal(defaultCacheRatio) - if err != nil { - common.SysError("error marshalling default cache ratio: " + err.Error()) +func GetCreateCacheRatio(name string) (float64, bool) { + ratio, ok := defaultCreateCacheRatio[name] + if !ok { + return 1.25, false // Default to 1.25 if not found } - return string(jsonBytes) -} - -// GetDefaultCacheRatioMap returns the default cache ratio map -func GetDefaultCacheRatioMap() map[string]float64 { - return defaultCacheRatio + return ratio, true } diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index fb87d54f..21d0a979 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -26,8 +26,14 @@ import { } from '@douyinfe/semi-ui'; import { ITEMS_PER_PAGE } from '../constants'; import { - renderAudioModelPrice, renderGroup, - renderModelPrice, renderModelPriceSimple, + renderAudioModelPrice, + renderClaudeLogContent, + renderClaudeModelPrice, + renderClaudeModelPriceSimple, + renderGroup, + renderLogContent, + renderModelPrice, + renderModelPriceSimple, renderNumber, renderQuota, stringToColor @@ -564,13 +570,23 @@ const LogsTable = () => { ); } - let content = renderModelPriceSimple( - other.model_ratio, - other.model_price, - other.group_ratio, - other.cache_tokens || 0, - other.cache_ratio || 1.0, - ); + let content = other?.claude + ? renderClaudeModelPriceSimple( + other.model_ratio, + other.model_price, + other.group_ratio, + other.cache_tokens || 0, + other.cache_ratio || 1.0, + other.cache_creation_tokens || 0, + other.cache_creation_ratio || 1.0, + ) + : renderModelPriceSimple( + other.model_ratio, + other.model_price, + other.group_ratio, + other.cache_tokens || 0, + other.cache_ratio || 1.0, + ); return ( { value: other.cache_tokens, }); } - expandDataLocal.push({ - key: t('日志详情'), - value: logs[i].content, - }); + if (other?.cache_creation_tokens > 0) { + expandDataLocal.push({ + key: t('缓存创建 Tokens'), + value: other.cache_creation_tokens, + }); + } + if (logs[i].type === 2) { + expandDataLocal.push({ + key: t('日志详情'), + value: other?.claude + ? renderClaudeLogContent( + other?.model_ratio, + other.completion_ratio, + other.model_price, + other.group_ratio, + other.user_group_ratio, + other.cache_ratio || 1.0, + other.cache_creation_ratio || 1.0 + ) + : renderLogContent( + other?.model_ratio, + other.completion_ratio, + other.model_price, + other.group_ratio, + other.user_group_ratio + ), + }); + } if (logs[i].type === 2) { let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== ''; if (modelMapped) { @@ -850,6 +890,19 @@ const LogsTable = () => { other?.cache_tokens || 0, other?.cache_ratio || 1.0, ); + } else if (other?.claude) { + content = renderClaudeModelPrice( + logs[i].prompt_tokens, + logs[i].completion_tokens, + other.model_ratio, + other.model_price, + other.completion_ratio, + other.group_ratio, + other.cache_tokens || 0, + other.cache_ratio || 1.0, + other.cache_creation_tokens || 0, + other.cache_creation_ratio || 1.0, + ); } else { content = renderModelPrice( logs[i].prompt_tokens, diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index 3ac81420..d1396191 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -325,9 +325,8 @@ export function renderModelPrice( return ( <>
-

{i18next.t('提示价格:${{price}} = ${{total}} / 1M tokens', { +

{i18next.t('提示价格:${{price}} / 1M tokens', { price: inputRatioPrice, - total: inputRatioPrice })}

{i18next.t('补全价格:${{price}} * {{completionRatio}} = ${{total}} / 1M tokens (补全倍率: {{completionRatio}})', { price: inputRatioPrice, @@ -445,9 +444,8 @@ export function renderAudioModelPrice( return ( <>

-

{i18next.t('提示价格:${{price}} = ${{total}} / 1M tokens', { +

{i18next.t('提示价格:${{price}} / 1M tokens', { price: inputRatioPrice, - total: inputRatioPrice })}

{i18next.t('补全价格:${{price}} * {{completionRatio}} = ${{total}} / 1M tokens (补全倍率: {{completionRatio}})', { price: inputRatioPrice, @@ -654,3 +652,194 @@ export function stringToColor(str) { let i = sum % colors.length; return colors[i]; } + +export function renderClaudeModelPrice( + inputTokens, + completionTokens, + modelRatio, + modelPrice = -1, + completionRatio, + groupRatio, + cacheTokens = 0, + cacheRatio = 1.0, + cacheCreationTokens = 0, + cacheCreationRatio = 1.0, +) { + const ratioLabel = false ? i18next.t('专属倍率') : i18next.t('分组倍率'); + + if (modelPrice !== -1) { + return i18next.t('模型价格:${{price}} * {{ratioType}}:{{ratio}} = ${{total}}', { + price: modelPrice, + ratioType: ratioLabel, + ratio: groupRatio, + total: modelPrice * groupRatio + }); + } else { + if (completionRatio === undefined) { + completionRatio = 0; + } + + const completionRatioValue = completionRatio || 0; + const inputRatioPrice = modelRatio * 2.0; + const completionRatioPrice = modelRatio * 2.0 * completionRatioValue; + let cacheRatioPrice = (modelRatio * 2.0 * cacheRatio).toFixed(2); + let cacheCreationRatioPrice = modelRatio * 2.0 * cacheCreationRatio; + + // Calculate effective input tokens (non-cached + cached with ratio applied + cache creation with ratio applied) + const nonCachedTokens = inputTokens; + const effectiveInputTokens = nonCachedTokens + + (cacheTokens * cacheRatio) + + (cacheCreationTokens * cacheCreationRatio); + + let price = + (effectiveInputTokens / 1000000) * inputRatioPrice * groupRatio + + (completionTokens / 1000000) * completionRatioPrice * groupRatio; + + return ( + <> +

+

{i18next.t('提示价格:${{price}} / 1M tokens', { + price: inputRatioPrice, + })}

+

{i18next.t('补全价格:${{price}} * {{ratio}} = ${{total}} / 1M tokens', { + price: inputRatioPrice, + ratio: completionRatio, + total: completionRatioPrice + })}

+ {cacheTokens > 0 && ( +

{i18next.t('缓存价格:${{price}} * {{ratio}} = ${{total}} / 1M tokens (缓存倍率: {{cacheRatio}})', { + price: inputRatioPrice, + ratio: cacheRatio, + total: cacheRatioPrice, + cacheRatio: cacheRatio + })}

+ )} + {cacheCreationTokens > 0 && ( +

{i18next.t('缓存创建价格:${{price}} * {{ratio}} = ${{total}} / 1M tokens (缓存创建倍率: {{cacheCreationRatio}})', { + price: inputRatioPrice, + ratio: cacheCreationRatio, + total: cacheCreationRatioPrice, + cacheCreationRatio: cacheCreationRatio + })}

+ )} +

+

+ {(cacheTokens > 0 || cacheCreationTokens > 0) ? + i18next.t('提示 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 缓存创建 {{cacheCreationInput}} tokens / 1M tokens * ${{cacheCreationPrice}} + 补全 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', { + nonCacheInput: nonCachedTokens, + cacheInput: cacheTokens, + cacheRatio: cacheRatio, + cacheCreationInput: cacheCreationTokens, + cacheCreationRatio: cacheCreationRatio, + cachePrice: cacheRatioPrice, + cacheCreationPrice: cacheCreationRatioPrice, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + total: price.toFixed(6) + }) : + i18next.t('提示 {{input}} tokens / 1M tokens * ${{price}} + 补全 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + total: price.toFixed(6) + }) + } +

+

{i18next.t('仅供参考,以实际扣费为准')}

+
+ + ); + } +} + +export function renderClaudeLogContent( + modelRatio, + completionRatio, + modelPrice = -1, + groupRatio, + cacheRatio = 1.0, + cacheCreationRatio = 1.0, +) { + const ratioLabel = false ? i18next.t('专属倍率') : i18next.t('分组倍率'); + + if (modelPrice !== -1) { + return i18next.t('模型价格 ${{price}},{{ratioType}} {{ratio}}', { + price: modelPrice, + ratioType: ratioLabel, + ratio: groupRatio + }); + } else { + return i18next.t('模型倍率 {{modelRatio}},补全倍率 {{completionRatio}},缓存倍率 {{cacheRatio}},缓存创建倍率 {{cacheCreationRatio}},{{ratioType}} {{ratio}}', { + modelRatio: modelRatio, + completionRatio: completionRatio, + cacheRatio: cacheRatio, + cacheCreationRatio: cacheCreationRatio, + ratioType: ratioLabel, + ratio: groupRatio + }); + } +} + +export function renderClaudeModelPriceSimple( + modelRatio, + modelPrice = -1, + groupRatio, + cacheTokens = 0, + cacheRatio = 1.0, + cacheCreationTokens = 0, + cacheCreationRatio = 1.0, +) { + const ratioLabel = false ? i18next.t('专属倍率') : i18next.t('分组'); + + if (modelPrice !== -1) { + return i18next.t('价格:${{price}} * {{ratioType}}:{{ratio}}', { + price: modelPrice, + ratioType: ratioLabel, + ratio: groupRatio + }); + } else { + if (cacheTokens !== 0 || cacheCreationTokens !== 0) { + return i18next.t('模型: {{ratio}} * {{ratioType}}: {{groupRatio}} * 缓存: {{cacheRatio}}', { + ratio: modelRatio, + ratioType: ratioLabel, + groupRatio: groupRatio, + cacheRatio: cacheRatio, + cacheCreationRatio: cacheCreationRatio + }); + } else { + return i18next.t('模型: {{ratio}} * {{ratioType}}: {{groupRatio}}', { + ratio: modelRatio, + ratioType: ratioLabel, + groupRatio: groupRatio + }); + } + } +} + +export function renderLogContent( + modelRatio, + completionRatio, + modelPrice = -1, + groupRatio +) { + const ratioLabel = false ? i18next.t('专属倍率') : i18next.t('分组倍率'); + + if (modelPrice !== -1) { + return i18next.t('模型价格 ${{price}},{{ratioType}} {{ratio}}', { + price: modelPrice, + ratioType: ratioLabel, + ratio: groupRatio + }); + } else { + return i18next.t('模型倍率 {{modelRatio}},补全倍率 {{completionRatio}},{{ratioType}} {{ratio}}', { + modelRatio: modelRatio, + completionRatio: completionRatio, + ratioType: ratioLabel, + ratio: groupRatio + }); + } +}