112 lines
4.4 KiB
Go
112 lines
4.4 KiB
Go
package claude
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/QuantumNous/new-api/dto"
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
"github.com/QuantumNous/new-api/setting/model_setting"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
func TestPatchClaudeMessageDeltaUsageDataPreserveUnknownFields(t *testing.T) {
|
|
originalData := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":53},"vendor_meta":{"trace_id":"trace_001"}}`
|
|
usage := &dto.ClaudeUsage{
|
|
InputTokens: 100,
|
|
CacheReadInputTokens: 30,
|
|
CacheCreationInputTokens: 50,
|
|
}
|
|
|
|
patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
|
|
|
|
require.Equal(t, "message_delta", gjson.Get(patchedData, "type").String())
|
|
require.Equal(t, "end_turn", gjson.Get(patchedData, "delta.stop_reason").String())
|
|
require.Equal(t, "trace_001", gjson.Get(patchedData, "vendor_meta.trace_id").String())
|
|
require.EqualValues(t, 53, gjson.Get(patchedData, "usage.output_tokens").Int())
|
|
require.EqualValues(t, 100, gjson.Get(patchedData, "usage.input_tokens").Int())
|
|
require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
|
|
require.EqualValues(t, 50, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Int())
|
|
}
|
|
|
|
func TestPatchClaudeMessageDeltaUsageDataZeroValueChecks(t *testing.T) {
|
|
originalData := `{"type":"message_delta","usage":{"output_tokens":53,"input_tokens":9,"cache_read_input_tokens":0}}`
|
|
usage := &dto.ClaudeUsage{
|
|
InputTokens: 100,
|
|
CacheReadInputTokens: 30,
|
|
CacheCreationInputTokens: 0,
|
|
}
|
|
|
|
patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
|
|
|
|
require.EqualValues(t, 9, gjson.Get(patchedData, "usage.input_tokens").Int())
|
|
require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
|
|
assert.False(t, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Exists())
|
|
}
|
|
|
|
func TestShouldSkipClaudeMessageDeltaUsagePatch(t *testing.T) {
|
|
originGlobalPassThrough := model_setting.GetGlobalSettings().PassThroughRequestEnabled
|
|
t.Cleanup(func() {
|
|
model_setting.GetGlobalSettings().PassThroughRequestEnabled = originGlobalPassThrough
|
|
})
|
|
|
|
model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
|
|
assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{}))
|
|
|
|
model_setting.GetGlobalSettings().PassThroughRequestEnabled = false
|
|
assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
|
|
ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: true}},
|
|
}))
|
|
assert.False(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
|
|
ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: false}},
|
|
}))
|
|
}
|
|
|
|
func TestBuildMessageDeltaPatchUsage(t *testing.T) {
|
|
t.Run("merge missing fields from claudeInfo", func(t *testing.T) {
|
|
claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{OutputTokens: 53}}
|
|
claudeInfo := &ClaudeResponseInfo{
|
|
Usage: &dto.Usage{
|
|
PromptTokens: 100,
|
|
PromptTokensDetails: dto.InputTokenDetails{
|
|
CachedTokens: 30,
|
|
CachedCreationTokens: 50,
|
|
},
|
|
ClaudeCacheCreation5mTokens: 10,
|
|
ClaudeCacheCreation1hTokens: 20,
|
|
},
|
|
}
|
|
|
|
usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
|
|
require.NotNil(t, usage)
|
|
require.EqualValues(t, 100, usage.InputTokens)
|
|
require.EqualValues(t, 30, usage.CacheReadInputTokens)
|
|
require.EqualValues(t, 50, usage.CacheCreationInputTokens)
|
|
require.EqualValues(t, 53, usage.OutputTokens)
|
|
require.NotNil(t, usage.CacheCreation)
|
|
require.EqualValues(t, 10, usage.CacheCreation.Ephemeral5mInputTokens)
|
|
require.EqualValues(t, 20, usage.CacheCreation.Ephemeral1hInputTokens)
|
|
})
|
|
|
|
t.Run("keep upstream non-zero values", func(t *testing.T) {
|
|
claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{
|
|
InputTokens: 9,
|
|
CacheReadInputTokens: 7,
|
|
CacheCreationInputTokens: 6,
|
|
}}
|
|
claudeInfo := &ClaudeResponseInfo{Usage: &dto.Usage{
|
|
PromptTokens: 100,
|
|
PromptTokensDetails: dto.InputTokenDetails{
|
|
CachedTokens: 30,
|
|
CachedCreationTokens: 50,
|
|
},
|
|
}}
|
|
|
|
usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
|
|
require.EqualValues(t, 9, usage.InputTokens)
|
|
require.EqualValues(t, 7, usage.CacheReadInputTokens)
|
|
require.EqualValues(t, 6, usage.CacheCreationInputTokens)
|
|
})
|
|
}
|