test: 完善自动化测试体系(7个模块,73个任务)
系统性地修复、补充和强化项目的自动化测试能力: 1. 测试基础设施修复 - 修复 stubConcurrencyCache 缺失方法和构造函数参数不匹配 - 创建 testutil 共享包(stubs.go, fixtures.go, httptest.go) - 为所有 Stub 添加编译期接口断言 2. 中间件测试补充 - 新增 JWT 认证中间件测试(有效/过期/篡改/缺失 Token) - 补充 rate_limiter 和 recovery 中间件测试场景 3. 网关核心路径测试 - 新增账户选择、等待队列、流式响应、并发控制、计费、Claude Code 检测测试 - 覆盖负载均衡、粘性会话、SSE 转发、槽位管理等关键逻辑 4. 前端测试体系(11个新测试文件,163个测试用例) - Pinia stores: auth, app, subscriptions - API client: 请求拦截器、响应拦截器、401 刷新 - Router guards: 认证重定向、管理员权限、简易模式限制 - Composables: useForm, useTableLoader, useClipboard - Components: LoginForm, ApiKeyCreate, Dashboard 5. CI/CD 流水线重构 - 重构 backend-ci.yml 为统一的 ci.yml - 前后端 4 个并行 Job + Postgres/Redis services - Race 检测、覆盖率收集与门禁、Docker 构建验证 6. E2E 自动化测试 - e2e-test.sh 自动化脚本(Docker 启动→健康检查→测试→清理) - 用户注册→登录→API Key→网关调用完整链路测试 - Mock 模式和 API Key 脱敏支持 7. 修复预存问题 - tlsfingerprint dialer_test.go 缺失 build tag 导致集成测试编译冲突 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -15,6 +15,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 编译期接口断言
|
||||
var _ HTTPUpstream = (*stubAntigravityUpstream)(nil)
|
||||
var _ HTTPUpstream = (*recordingOKUpstream)(nil)
|
||||
var _ AccountRepository = (*stubAntigravityAccountRepo)(nil)
|
||||
var _ SchedulerCache = (*stubSchedulerCache)(nil)
|
||||
|
||||
type stubAntigravityUpstream struct {
|
||||
firstBase string
|
||||
secondBase string
|
||||
|
||||
310
backend/internal/service/billing_service_test.go
Normal file
310
backend/internal/service/billing_service_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestBillingService() *BillingService {
|
||||
return NewBillingService(&config.Config{}, nil)
|
||||
}
|
||||
|
||||
func TestCalculateCost_BasicComputation(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 使用 claude-sonnet-4 的回退价格:Input $3/MTok, Output $15/MTok
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 500,
|
||||
}
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075
|
||||
expectedInput := 1000 * 3e-6
|
||||
expectedOutput := 500 * 15e-6
|
||||
require.InDelta(t, expectedInput, cost.InputCost, 1e-10)
|
||||
require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_WithCacheTokens(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 500,
|
||||
CacheCreationTokens: 2000,
|
||||
CacheReadTokens: 3000,
|
||||
}
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedCacheCreation := 2000 * 3.75e-6
|
||||
expectedCacheRead := 3000 * 0.3e-6
|
||||
require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10)
|
||||
|
||||
expectedTotal := cost.InputCost + cost.OutputCost + expectedCacheCreation + expectedCacheRead
|
||||
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_RateMultiplier(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
|
||||
cost1x, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
cost2x, err := svc.CalculateCost("claude-sonnet-4", tokens, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TotalCost 不受倍率影响,ActualCost 翻倍
|
||||
require.InDelta(t, cost1x.TotalCost, cost2x.TotalCost, 1e-10)
|
||||
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tests := []struct {
|
||||
model string
|
||||
expectedInput float64
|
||||
}{
|
||||
{"claude-opus-4.5-20250101", 5e-6},
|
||||
{"claude-3-opus-20240229", 15e-6},
|
||||
{"claude-sonnet-4-20250514", 3e-6},
|
||||
{"claude-3-5-sonnet-20241022", 3e-6},
|
||||
{"claude-3-5-haiku-20241022", 1e-6},
|
||||
{"claude-3-haiku-20240307", 0.25e-6},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
pricing, err := svc.GetModelPricing(tt.model)
|
||||
require.NoError(t, err, "模型 %s", tt.model)
|
||||
require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12, "模型 %s 输入价格", tt.model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelPricing_CaseInsensitive(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
p1, err := svc.GetModelPricing("Claude-Sonnet-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
p2, err := svc.GetModelPricing("claude-sonnet-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格
|
||||
pricing, err := svc.GetModelPricing("claude-unknown-model")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 50000,
|
||||
OutputTokens: 1000,
|
||||
CacheReadTokens: 100000,
|
||||
}
|
||||
// 总输入 150k < 200k 阈值,应走正常计费
|
||||
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 缓存 210k + 输入 10k = 220k > 200k 阈值
|
||||
// 缓存已超阈值:范围内 200k 缓存,范围外 10k 缓存 + 10k 输入
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 10000,
|
||||
OutputTokens: 1000,
|
||||
CacheReadTokens: 210000,
|
||||
}
|
||||
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 范围内:200k cache + 0 input + 1k output
|
||||
inRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{
|
||||
InputTokens: 0,
|
||||
OutputTokens: 1000,
|
||||
CacheReadTokens: 200000,
|
||||
}, 1.0)
|
||||
|
||||
// 范围外:10k cache + 10k input,倍率 2.0
|
||||
outRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{
|
||||
InputTokens: 10000,
|
||||
CacheReadTokens: 10000,
|
||||
}, 2.0)
|
||||
|
||||
require.InDelta(t, inRange.ActualCost+outRange.ActualCost, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 缓存 100k + 输入 150k = 250k > 200k 阈值
|
||||
// 缓存未超阈值:范围内 100k 缓存 + 100k 输入,范围外 50k 输入
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 150000,
|
||||
OutputTokens: 1000,
|
||||
CacheReadTokens: 100000,
|
||||
}
|
||||
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, cost.ActualCost > 0, "费用应大于 0")
|
||||
|
||||
// 正常费用不含长上下文
|
||||
normalCost, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.True(t, cost.ActualCost > normalCost.ActualCost, "长上下文费用应高于正常费用")
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_DisabledThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0}
|
||||
|
||||
// threshold <= 0 应禁用长上下文计费
|
||||
cost1, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 0, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
cost2, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, cost2.ActualCost, cost1.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 300000}
|
||||
|
||||
// extraMultiplier <= 1 应禁用长上下文计费
|
||||
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateImageCost(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
price := 0.134
|
||||
cfg := &ImagePriceConfig{Price1K: &price}
|
||||
cost := svc.CalculateImageCost("gpt-image-1", "1K", 3, cfg, 1.0)
|
||||
|
||||
require.InDelta(t, 0.134*3, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateSoraVideoCost(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
price := 0.5
|
||||
cfg := &SoraPriceConfig{VideoPricePerRequest: &price}
|
||||
cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0)
|
||||
|
||||
require.InDelta(t, 0.5, cost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateSoraVideoCost_HDModel(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
hdPrice := 1.0
|
||||
normalPrice := 0.5
|
||||
cfg := &SoraPriceConfig{
|
||||
VideoPricePerRequest: &normalPrice,
|
||||
VideoPricePerRequestHD: &hdPrice,
|
||||
}
|
||||
cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0)
|
||||
require.InDelta(t, 1.0, cost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestIsModelSupported(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
require.True(t, svc.IsModelSupported("claude-sonnet-4"))
|
||||
require.True(t, svc.IsModelSupported("Claude-Opus-4.5"))
|
||||
require.True(t, svc.IsModelSupported("claude-3-haiku"))
|
||||
require.False(t, svc.IsModelSupported("gpt-4o"))
|
||||
require.False(t, svc.IsModelSupported("gemini-pro"))
|
||||
}
|
||||
|
||||
func TestCalculateCost_ZeroTokens(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", UsageTokens{}, 1.0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
func TestCalculateCost_LargeTokenCount(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 1_000_000,
|
||||
OutputTokens: 1_000_000,
|
||||
}
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15
|
||||
require.InDelta(t, 3.0, cost.InputCost, 1e-6)
|
||||
require.InDelta(t, 15.0, cost.OutputCost, 1e-6)
|
||||
require.False(t, math.IsNaN(cost.TotalCost))
|
||||
require.False(t, math.IsInf(cost.TotalCost, 0))
|
||||
}
|
||||
282
backend/internal/service/claude_code_detection_test.go
Normal file
282
backend/internal/service/claude_code_detection_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestValidator() *ClaudeCodeValidator {
|
||||
return NewClaudeCodeValidator()
|
||||
}
|
||||
|
||||
// validClaudeCodeBody 构造一个完整有效的 Claude Code 请求体
|
||||
func validClaudeCodeBody() map[string]any {
|
||||
return map[string]any{
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"system": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
},
|
||||
"metadata": map[string]any{
|
||||
"user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_" + "12345678-1234-1234-1234-123456789abc",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_ClaudeCLIUserAgent(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
want bool
|
||||
}{
|
||||
{"标准版本号", "claude-cli/1.0.0", true},
|
||||
{"多位版本号", "claude-cli/12.34.56", true},
|
||||
{"大写开头", "Claude-CLI/1.0.0", true},
|
||||
{"非 claude-cli", "curl/7.64.1", false},
|
||||
{"空 User-Agent", "", false},
|
||||
{"部分匹配", "not-claude-cli/1.0.0", false},
|
||||
{"缺少版本号", "claude-cli/", false},
|
||||
{"版本格式不对", "claude-cli/1.0", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, v.ValidateUserAgent(tt.ua), "UA: %q", tt.ua)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_NonMessagesPath_UAOnly(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
// 非 messages 路径只检查 UA
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
|
||||
result := v.Validate(req, nil)
|
||||
require.True(t, result, "非 messages 路径只需 UA 匹配")
|
||||
}
|
||||
|
||||
func TestValidate_NonMessagesPath_InvalidUA(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Set("User-Agent", "curl/7.64.1")
|
||||
|
||||
result := v.Validate(req, nil)
|
||||
require.False(t, result, "UA 不匹配时应返回 false")
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_FullValid(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
result := v.Validate(req, validClaudeCodeBody())
|
||||
require.True(t, result, "完整有效请求应通过")
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_MissingHeaders(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
body := validClaudeCodeBody()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
missingHeader string
|
||||
}{
|
||||
{"缺少 X-App", "X-App"},
|
||||
{"缺少 anthropic-beta", "anthropic-beta"},
|
||||
{"缺少 anthropic-version", "anthropic-version"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "beta")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Del(tt.missingHeader)
|
||||
|
||||
result := v.Validate(req, body)
|
||||
require.False(t, result, "缺少 %s 应返回 false", tt.missingHeader)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_InvalidMetadataUserID(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]any
|
||||
}{
|
||||
{"缺少 metadata", nil},
|
||||
{"缺少 user_id", map[string]any{"other": "value"}},
|
||||
{"空 user_id", map[string]any{"user_id": ""}},
|
||||
{"格式错误", map[string]any{"user_id": "invalid-format"}},
|
||||
{"hex 长度不足", map[string]any{"user_id": "user_abc_account__session_uuid"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "beta")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
body := map[string]any{
|
||||
"model": "claude-sonnet-4",
|
||||
"system": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
},
|
||||
}
|
||||
if tt.metadata != nil {
|
||||
body["metadata"] = tt.metadata
|
||||
}
|
||||
|
||||
result := v.Validate(req, body)
|
||||
require.False(t, result, "metadata.user_id: %v", tt.metadata)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_InvalidSystemPrompt(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "beta")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
body := map[string]any{
|
||||
"model": "claude-sonnet-4",
|
||||
"system": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "Generate JSON data for testing database migrations.",
|
||||
},
|
||||
},
|
||||
"metadata": map[string]any{
|
||||
"user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_12345678-1234-1234-1234-123456789abc",
|
||||
},
|
||||
}
|
||||
|
||||
result := v.Validate(req, body)
|
||||
require.False(t, result, "无关系统提示词应返回 false")
|
||||
}
|
||||
|
||||
func TestValidate_MaxTokensOneHaikuBypass(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
// 不设置 X-App 等头,通过 context 标记为 haiku 探测请求
|
||||
ctx := context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// 即使 body 不包含 system prompt,也应通过
|
||||
result := v.Validate(req, map[string]any{"model": "claude-3-haiku", "max_tokens": 1})
|
||||
require.True(t, result, "max_tokens=1+haiku 探测请求应绕过严格验证")
|
||||
}
|
||||
|
||||
func TestSystemPromptSimilarity(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prompt string
|
||||
want bool
|
||||
}{
|
||||
{"精确匹配", "You are Claude Code, Anthropic's official CLI for Claude.", true},
|
||||
{"带多余空格", "You are Claude Code, Anthropic's official CLI for Claude.", true},
|
||||
{"Agent SDK 模板", "You are a Claude agent, built on Anthropic's Claude Agent SDK.", true},
|
||||
{"文件搜索专家模板", "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", true},
|
||||
{"对话摘要模板", "You are a helpful AI assistant tasked with summarizing conversations.", true},
|
||||
{"交互式 CLI 模板", "You are an interactive CLI tool that helps users", true},
|
||||
{"无关文本", "Write me a poem about cats", false},
|
||||
{"空文本", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body := map[string]any{
|
||||
"model": "claude-sonnet-4",
|
||||
"system": []any{
|
||||
map[string]any{"type": "text", "text": tt.prompt},
|
||||
},
|
||||
}
|
||||
result := v.IncludesClaudeCodeSystemPrompt(body)
|
||||
require.Equal(t, tt.want, result, "提示词: %q", tt.prompt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiceCoefficient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a string
|
||||
b string
|
||||
want float64
|
||||
tol float64
|
||||
}{
|
||||
{"相同字符串", "hello", "hello", 1.0, 0.001},
|
||||
{"完全不同", "abc", "xyz", 0.0, 0.001},
|
||||
{"空字符串", "", "hello", 0.0, 0.001},
|
||||
{"单字符", "a", "b", 0.0, 0.001},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := diceCoefficient(tt.a, tt.b)
|
||||
require.InDelta(t, tt.want, result, tt.tol)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsClaudeCodeClient_Context(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 默认应为 false
|
||||
require.False(t, IsClaudeCodeClient(ctx))
|
||||
|
||||
// 设置为 true
|
||||
ctx = SetClaudeCodeClient(ctx, true)
|
||||
require.True(t, IsClaudeCodeClient(ctx))
|
||||
|
||||
// 设置为 false
|
||||
ctx = SetClaudeCodeClient(ctx, false)
|
||||
require.False(t, IsClaudeCodeClient(ctx))
|
||||
}
|
||||
|
||||
func TestValidate_NilBody_MessagesPath(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "beta")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
result := v.Validate(req, nil)
|
||||
require.False(t, result, "nil body 的 messages 请求应返回 false")
|
||||
}
|
||||
280
backend/internal/service/concurrency_service_test.go
Normal file
280
backend/internal/service/concurrency_service_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
|
||||
type stubConcurrencyCacheForTest struct {
|
||||
acquireResult bool
|
||||
acquireErr error
|
||||
releaseErr error
|
||||
concurrency int
|
||||
concurrencyErr error
|
||||
waitAllowed bool
|
||||
waitErr error
|
||||
waitCount int
|
||||
waitCountErr error
|
||||
loadBatch map[int64]*AccountLoadInfo
|
||||
loadBatchErr error
|
||||
usersLoadBatch map[int64]*UserLoadInfo
|
||||
usersLoadErr error
|
||||
cleanupErr error
|
||||
|
||||
// 记录调用
|
||||
releasedAccountIDs []int64
|
||||
releasedRequestIDs []string
|
||||
}
|
||||
|
||||
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
|
||||
|
||||
func (c *stubConcurrencyCacheForTest) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) {
|
||||
return c.acquireResult, c.acquireErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error {
|
||||
c.releasedAccountIDs = append(c.releasedAccountIDs, accountID)
|
||||
c.releasedRequestIDs = append(c.releasedRequestIDs, requestID)
|
||||
return c.releaseErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
|
||||
return c.concurrency, c.concurrencyErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
|
||||
return c.waitAllowed, c.waitErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) {
|
||||
return c.waitCount, c.waitCountErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) {
|
||||
return c.acquireResult, c.acquireErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error {
|
||||
return c.releaseErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) {
|
||||
return c.concurrency, c.concurrencyErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
|
||||
return c.waitAllowed, c.waitErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
return c.loadBatch, c.loadBatchErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||
return c.usersLoadBatch, c.usersLoadErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
|
||||
return c.cleanupErr
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_Success(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireResult: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired)
|
||||
require.NotNil(t, result.ReleaseFunc)
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_Failure(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireResult: false}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Acquired)
|
||||
require.Nil(t, result.ReleaseFunc)
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) {
|
||||
svc := NewConcurrencyService(&stubConcurrencyCacheForTest{})
|
||||
|
||||
for _, maxConcurrency := range []int{0, -1} {
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 1, maxConcurrency)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired, "maxConcurrency=%d 应无限制通过", maxConcurrency)
|
||||
require.NotNil(t, result.ReleaseFunc, "ReleaseFunc 应为 no-op 函数")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_CacheError(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireResult: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 42, 5)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired)
|
||||
|
||||
// 调用 ReleaseFunc 应释放槽位
|
||||
result.ReleaseFunc()
|
||||
|
||||
require.Len(t, cache.releasedAccountIDs, 1)
|
||||
require.Equal(t, int64(42), cache.releasedAccountIDs[0])
|
||||
require.Len(t, cache.releasedRequestIDs, 1)
|
||||
require.NotEmpty(t, cache.releasedRequestIDs[0], "requestID 不应为空")
|
||||
}
|
||||
|
||||
func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireResult: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
// 用户槽位获取应独立于账户槽位
|
||||
result, err := svc.AcquireUserSlot(context.Background(), 100, 3)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired)
|
||||
require.NotNil(t, result.ReleaseFunc)
|
||||
}
|
||||
|
||||
func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
|
||||
svc := NewConcurrencyService(&stubConcurrencyCacheForTest{})
|
||||
|
||||
result, err := svc.AcquireUserSlot(context.Background(), 1, 0)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired)
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
|
||||
expected := map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},
|
||||
2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100},
|
||||
}
|
||||
cache := &stubConcurrencyCacheForTest{loadBatch: expected}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
accounts := []AccountWithConcurrency{
|
||||
{ID: 1, MaxConcurrency: 5},
|
||||
{ID: 2, MaxConcurrency: 5},
|
||||
}
|
||||
result, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatch_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
|
||||
result, err := svc.GetAccountsLoadBatch(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_Success(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed)
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_QueueFull(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitAllowed: false}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err)
|
||||
require.False(t, allowed)
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_FailOpen(t *testing.T) {
|
||||
// Redis 错误时应 fail-open(允许请求通过)
|
||||
cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err, "Redis 错误不应传播")
|
||||
require.True(t, allowed, "Redis 错误时应 fail-open")
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed, "nil cache 应 fail-open")
|
||||
}
|
||||
|
||||
func TestCalculateMaxWait(t *testing.T) {
|
||||
tests := []struct {
|
||||
concurrency int
|
||||
expected int
|
||||
}{
|
||||
{5, 25}, // 5 + 20
|
||||
{1, 21}, // 1 + 20
|
||||
{0, 21}, // min(1) + 20
|
||||
{-1, 21}, // min(1) + 20
|
||||
{10, 30}, // 10 + 20
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := CalculateMaxWait(tt.concurrency)
|
||||
require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountWaitingCount(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitCount: 5}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
count, err := svc.GetAccountWaitingCount(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 5, count)
|
||||
}
|
||||
|
||||
func TestGetAccountWaitingCount_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
|
||||
count, err := svc.GetAccountWaitingCount(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func TestGetAccountConcurrencyBatch(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{concurrency: 3}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 3)
|
||||
for _, id := range []int64{1, 2, 3} {
|
||||
require.Equal(t, 3, result[id])
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementAccountWaitCount_FailOpen(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10)
|
||||
require.NoError(t, err, "Redis 错误不应传播")
|
||||
require.True(t, allowed, "Redis 错误时应 fail-open")
|
||||
}
|
||||
|
||||
func TestIncrementAccountWaitCount_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
|
||||
allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed)
|
||||
}
|
||||
198
backend/internal/service/gateway_account_selection_test.go
Normal file
198
backend/internal/service/gateway_account_selection_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func testTimePtr(t time.Time) *time.Time { return &t }
|
||||
|
||||
func makeAccWithLoad(id int64, priority int, loadRate int, lastUsed *time.Time, accType string) accountWithLoad {
|
||||
return accountWithLoad{
|
||||
account: &Account{
|
||||
ID: id,
|
||||
Priority: priority,
|
||||
LastUsedAt: lastUsed,
|
||||
Type: accType,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
},
|
||||
loadInfo: &AccountLoadInfo{
|
||||
AccountID: id,
|
||||
CurrentConcurrency: 0,
|
||||
LoadRate: loadRate,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// --- sortAccountsByPriorityAndLastUsed ---
|
||||
|
||||
func TestSortAccountsByPriorityAndLastUsed_ByPriority(t *testing.T) {
|
||||
now := time.Now()
|
||||
accounts := []*Account{
|
||||
{ID: 1, Priority: 5, LastUsedAt: testTimePtr(now)},
|
||||
{ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)},
|
||||
{ID: 3, Priority: 3, LastUsedAt: testTimePtr(now)},
|
||||
}
|
||||
sortAccountsByPriorityAndLastUsed(accounts, false)
|
||||
require.Equal(t, int64(2), accounts[0].ID, "优先级最低的排第一")
|
||||
require.Equal(t, int64(3), accounts[1].ID)
|
||||
require.Equal(t, int64(1), accounts[2].ID)
|
||||
}
|
||||
|
||||
func TestSortAccountsByPriorityAndLastUsed_SamePriorityByLastUsed(t *testing.T) {
|
||||
now := time.Now()
|
||||
accounts := []*Account{
|
||||
{ID: 1, Priority: 1, LastUsedAt: testTimePtr(now)},
|
||||
{ID: 2, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))},
|
||||
{ID: 3, Priority: 1, LastUsedAt: nil},
|
||||
}
|
||||
sortAccountsByPriorityAndLastUsed(accounts, false)
|
||||
require.Equal(t, int64(3), accounts[0].ID, "nil LastUsedAt 排最前")
|
||||
require.Equal(t, int64(2), accounts[1].ID, "更早使用的排前面")
|
||||
require.Equal(t, int64(1), accounts[2].ID)
|
||||
}
|
||||
|
||||
func TestSortAccountsByPriorityAndLastUsed_PreferOAuth(t *testing.T) {
|
||||
accounts := []*Account{
|
||||
{ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeOAuth},
|
||||
}
|
||||
sortAccountsByPriorityAndLastUsed(accounts, true)
|
||||
require.Equal(t, int64(2), accounts[0].ID, "preferOAuth 时 OAuth 账号排前面")
|
||||
}
|
||||
|
||||
func TestSortAccountsByPriorityAndLastUsed_StableSort(t *testing.T) {
|
||||
accounts := []*Account{
|
||||
{ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey},
|
||||
{ID: 3, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey},
|
||||
}
|
||||
sortAccountsByPriorityAndLastUsed(accounts, false)
|
||||
// 稳定排序:相同键值的元素保持原始顺序
|
||||
require.Equal(t, int64(1), accounts[0].ID)
|
||||
require.Equal(t, int64(2), accounts[1].ID)
|
||||
require.Equal(t, int64(3), accounts[2].ID)
|
||||
}
|
||||
|
||||
func TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
accounts := []*Account{
|
||||
{ID: 1, Priority: 2, LastUsedAt: nil},
|
||||
{ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)},
|
||||
{ID: 3, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))},
|
||||
{ID: 4, Priority: 2, LastUsedAt: testTimePtr(now.Add(-2 * time.Hour))},
|
||||
}
|
||||
sortAccountsByPriorityAndLastUsed(accounts, false)
|
||||
// 优先级1排前:nil < earlier
|
||||
require.Equal(t, int64(3), accounts[0].ID, "优先级1 + 更早")
|
||||
require.Equal(t, int64(2), accounts[1].ID, "优先级1 + 现在")
|
||||
// 优先级2排后:nil < time
|
||||
require.Equal(t, int64(1), accounts[2].ID, "优先级2 + nil")
|
||||
require.Equal(t, int64(4), accounts[3].ID, "优先级2 + 有时间")
|
||||
}
|
||||
|
||||
// --- selectByCallCount ---
|
||||
|
||||
func TestSelectByCallCount_Empty(t *testing.T) {
|
||||
result := selectByCallCount(nil, nil, false)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestSelectByCallCount_Single(t *testing.T) {
|
||||
accounts := []accountWithLoad{
|
||||
makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey),
|
||||
}
|
||||
result := selectByCallCount(accounts, map[int64]*ModelLoadInfo{1: {CallCount: 10}}, false)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, int64(1), result.account.ID)
|
||||
}
|
||||
|
||||
func TestSelectByCallCount_NilModelLoadFallsBackToLRU(t *testing.T) {
|
||||
now := time.Now()
|
||||
accounts := []accountWithLoad{
|
||||
makeAccWithLoad(1, 1, 50, testTimePtr(now), AccountTypeAPIKey),
|
||||
makeAccWithLoad(2, 1, 50, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey),
|
||||
}
|
||||
result := selectByCallCount(accounts, nil, false)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, int64(2), result.account.ID, "nil modelLoadMap 应回退到 LRU 选择")
|
||||
}
|
||||
|
||||
func TestSelectByCallCount_SelectsMinCallCount(t *testing.T) {
|
||||
accounts := []accountWithLoad{
|
||||
makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey),
|
||||
makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey),
|
||||
makeAccWithLoad(3, 1, 50, nil, AccountTypeAPIKey),
|
||||
}
|
||||
modelLoad := map[int64]*ModelLoadInfo{
|
||||
1: {CallCount: 100},
|
||||
2: {CallCount: 5},
|
||||
3: {CallCount: 50},
|
||||
}
|
||||
// 运行多次确认总是选调用次数最少的
|
||||
for i := 0; i < 10; i++ {
|
||||
result := selectByCallCount(accounts, modelLoad, false)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, int64(2), result.account.ID, "应选择调用次数最少的账号")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectByCallCount_NewAccountUsesAverage(t *testing.T) {
|
||||
accounts := []accountWithLoad{
|
||||
makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey),
|
||||
makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey),
|
||||
makeAccWithLoad(3, 1, 50, nil, AccountTypeAPIKey),
|
||||
}
|
||||
// 账号1和2有调用记录,账号3是新账号(CallCount=0)
|
||||
// 平均调用次数 = (100 + 200) / 2 = 150
|
||||
// 新账号用平均值 150,比账号1(100)多,所以应选账号1
|
||||
modelLoad := map[int64]*ModelLoadInfo{
|
||||
1: {CallCount: 100},
|
||||
2: {CallCount: 200},
|
||||
// 3 没有记录
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
result := selectByCallCount(accounts, modelLoad, false)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, int64(1), result.account.ID, "新账号虚拟调用次数(150)高于账号1(100),应选账号1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectByCallCount_AllNewAccountsFallToAvgZero(t *testing.T) {
|
||||
accounts := []accountWithLoad{
|
||||
makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey),
|
||||
makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey),
|
||||
}
|
||||
// 所有账号都是新的,avgCallCount = 0,所有人 effectiveCallCount 都是 0
|
||||
modelLoad := map[int64]*ModelLoadInfo{}
|
||||
validIDs := map[int64]bool{1: true, 2: true}
|
||||
for i := 0; i < 10; i++ {
|
||||
result := selectByCallCount(accounts, modelLoad, false)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, validIDs[result.account.ID], "所有新账号应随机选择")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectByCallCount_PreferOAuth(t *testing.T) {
|
||||
accounts := []accountWithLoad{
|
||||
makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey),
|
||||
makeAccWithLoad(2, 1, 50, nil, AccountTypeOAuth),
|
||||
}
|
||||
// 两个账号调用次数相同
|
||||
modelLoad := map[int64]*ModelLoadInfo{
|
||||
1: {CallCount: 10},
|
||||
2: {CallCount: 10},
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
result := selectByCallCount(accounts, modelLoad, true)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, int64(2), result.account.ID, "调用次数相同时应优先选择 OAuth 账号")
|
||||
}
|
||||
}
|
||||
203
backend/internal/service/gateway_streaming_test.go
Normal file
203
backend/internal/service/gateway_streaming_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- parseSSEUsage 测试 ---
|
||||
|
||||
func newMinimalGatewayService() *GatewayService {
|
||||
return &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
},
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_MessageStart(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
data := `{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation_input_tokens":50,"cache_read_input_tokens":200}}}`
|
||||
svc.parseSSEUsage(data, usage)
|
||||
|
||||
require.Equal(t, 100, usage.InputTokens)
|
||||
require.Equal(t, 50, usage.CacheCreationInputTokens)
|
||||
require.Equal(t, 200, usage.CacheReadInputTokens)
|
||||
require.Equal(t, 0, usage.OutputTokens, "message_start 不应设置 output_tokens")
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_MessageDelta(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
data := `{"type":"message_delta","usage":{"output_tokens":42}}`
|
||||
svc.parseSSEUsage(data, usage)
|
||||
|
||||
require.Equal(t, 42, usage.OutputTokens)
|
||||
require.Equal(t, 0, usage.InputTokens, "message_delta 的 output_tokens 不应影响已有的 input_tokens")
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_DeltaDoesNotOverwriteStartValues(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
// 先处理 message_start
|
||||
svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100}}}`, usage)
|
||||
require.Equal(t, 100, usage.InputTokens)
|
||||
|
||||
// 再处理 message_delta(output_tokens > 0, input_tokens = 0)
|
||||
svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":50}}`, usage)
|
||||
require.Equal(t, 100, usage.InputTokens, "delta 中 input_tokens=0 不应覆盖 start 中的值")
|
||||
require.Equal(t, 50, usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
// GLM 等 API 会在 delta 中包含所有 usage 信息
|
||||
svc.parseSSEUsage(`{"type":"message_delta","usage":{"input_tokens":200,"output_tokens":100,"cache_creation_input_tokens":30,"cache_read_input_tokens":60}}`, usage)
|
||||
require.Equal(t, 200, usage.InputTokens)
|
||||
require.Equal(t, 100, usage.OutputTokens)
|
||||
require.Equal(t, 30, usage.CacheCreationInputTokens)
|
||||
require.Equal(t, 60, usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_InvalidJSON(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
// 无效 JSON 不应 panic
|
||||
svc.parseSSEUsage("not json", usage)
|
||||
require.Equal(t, 0, usage.InputTokens)
|
||||
require.Equal(t, 0, usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_UnknownType(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
// 不是 message_start 或 message_delta 的类型
|
||||
svc.parseSSEUsage(`{"type":"content_block_delta","delta":{"text":"hello"}}`, usage)
|
||||
require.Equal(t, 0, usage.InputTokens)
|
||||
require.Equal(t, 0, usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_EmptyString(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
svc.parseSSEUsage("", usage)
|
||||
require.Equal(t, 0, usage.InputTokens)
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_DoneEvent(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
// [DONE] 事件不应影响 usage
|
||||
svc.parseSSEUsage("[DONE]", usage)
|
||||
require.Equal(t, 0, usage.InputTokens)
|
||||
}
|
||||
|
||||
// --- 流式响应端到端测试 ---
|
||||
|
||||
func TestHandleStreamingResponse_CacheTokens(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newMinimalGatewayService()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":20,\"cache_read_input_tokens\":30}}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":15}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: [DONE]\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 10, result.usage.InputTokens)
|
||||
require.Equal(t, 15, result.usage.OutputTokens)
|
||||
require.Equal(t, 20, result.usage.CacheCreationInputTokens)
|
||||
require.Equal(t, 30, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newMinimalGatewayService()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
|
||||
|
||||
go func() {
|
||||
// 直接关闭,不发送任何事件
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newMinimalGatewayService()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// 包含特殊字符的 content_block_delta(引号、换行、Unicode)
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello \\\"world\\\"\\n你好\"}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: [DONE]\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 5, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
|
||||
// 验证响应中包含转发的数据
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件")
|
||||
}
|
||||
120
backend/internal/service/gateway_waiting_queue_test.go
Normal file
120
backend/internal/service/gateway_waiting_queue_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDecrementWaitCount_NilCache 确保 nil cache 不会 panic
|
||||
func TestDecrementWaitCount_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
// 不应 panic
|
||||
svc.DecrementWaitCount(context.Background(), 1)
|
||||
}
|
||||
|
||||
// TestDecrementWaitCount_CacheError 确保 cache 错误不会传播
|
||||
func TestDecrementWaitCount_CacheError(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{}
|
||||
svc := NewConcurrencyService(cache)
|
||||
// DecrementWaitCount 使用 background context,错误只记录日志不传播
|
||||
svc.DecrementWaitCount(context.Background(), 1)
|
||||
}
|
||||
|
||||
// TestDecrementAccountWaitCount_NilCache 确保 nil cache 不会 panic
|
||||
func TestDecrementAccountWaitCount_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
svc.DecrementAccountWaitCount(context.Background(), 1)
|
||||
}
|
||||
|
||||
// TestDecrementAccountWaitCount_CacheError 确保 cache 错误不会传播
|
||||
func TestDecrementAccountWaitCount_CacheError(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{}
|
||||
svc := NewConcurrencyService(cache)
|
||||
svc.DecrementAccountWaitCount(context.Background(), 1)
|
||||
}
|
||||
|
||||
// TestWaitingQueueFlow_IncrementThenDecrement 测试完整的等待队列增减流程
|
||||
func TestWaitingQueueFlow_IncrementThenDecrement(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
// 进入等待队列
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed)
|
||||
|
||||
// 离开等待队列(不应 panic)
|
||||
svc.DecrementWaitCount(context.Background(), 1)
|
||||
}
|
||||
|
||||
// TestWaitingQueueFlow_AccountLevel 测试账号级等待队列流程
|
||||
func TestWaitingQueueFlow_AccountLevel(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
// 进入账号等待队列
|
||||
allowed, err := svc.IncrementAccountWaitCount(context.Background(), 42, 10)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed)
|
||||
|
||||
// 离开账号等待队列
|
||||
svc.DecrementAccountWaitCount(context.Background(), 42)
|
||||
}
|
||||
|
||||
// TestWaitingQueueFull_Returns429Signal 测试等待队列满时返回 false
|
||||
func TestWaitingQueueFull_Returns429Signal(t *testing.T) {
|
||||
// waitAllowed=false 模拟队列已满
|
||||
cache := &stubConcurrencyCacheForTest{waitAllowed: false}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
// 用户级等待队列满
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err)
|
||||
require.False(t, allowed, "等待队列满时应返回 false(调用方根据此返回 429)")
|
||||
|
||||
// 账号级等待队列满
|
||||
allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10)
|
||||
require.NoError(t, err)
|
||||
require.False(t, allowed, "账号等待队列满时应返回 false")
|
||||
}
|
||||
|
||||
// TestWaitingQueue_FailOpen_OnCacheError 测试 Redis 故障时 fail-open
|
||||
func TestWaitingQueue_FailOpen_OnCacheError(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis connection refused")}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
// 用户级:Redis 错误时允许通过
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err, "Redis 错误不应向调用方传播")
|
||||
require.True(t, allowed, "Redis 故障时应 fail-open 放行")
|
||||
|
||||
// 账号级:同样 fail-open
|
||||
allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10)
|
||||
require.NoError(t, err, "Redis 错误不应向调用方传播")
|
||||
require.True(t, allowed, "Redis 故障时应 fail-open 放行")
|
||||
}
|
||||
|
||||
// TestCalculateMaxWait_Scenarios 测试最大等待队列大小计算
|
||||
func TestCalculateMaxWait_Scenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
concurrency int
|
||||
expected int
|
||||
}{
|
||||
{5, 25}, // 5 + 20
|
||||
{10, 30}, // 10 + 20
|
||||
{1, 21}, // 1 + 20
|
||||
{0, 21}, // min(1) + 20
|
||||
{-1, 21}, // min(1) + 20
|
||||
{-10, 21}, // min(1) + 20
|
||||
{100, 120}, // 100 + 20
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := CalculateMaxWait(tt.concurrency)
|
||||
require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency)
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 编译期接口断言
|
||||
var _ AccountRepository = (*stubOpenAIAccountRepo)(nil)
|
||||
var _ GatewayCache = (*stubGatewayCache)(nil)
|
||||
|
||||
type stubOpenAIAccountRepo struct {
|
||||
AccountRepository
|
||||
accounts []Account
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ OpsRepository = (*stubOpsRepo)(nil)
|
||||
|
||||
type stubOpsRepo struct {
|
||||
OpsRepository
|
||||
overview *OpsDashboardOverview
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ SoraClient = (*stubSoraClientForPoll)(nil)
|
||||
|
||||
type stubSoraClientForPoll struct {
|
||||
imageStatus *SoraImageTaskStatus
|
||||
videoStatus *SoraVideoTaskStatus
|
||||
|
||||
@@ -14,7 +14,7 @@ func newTestSubscriptionService() *SubscriptionService {
|
||||
return &SubscriptionService{}
|
||||
}
|
||||
|
||||
func ptrFloat64(v float64) *float64 { return &v }
|
||||
func ptrFloat64(v float64) *float64 { return &v }
|
||||
func ptrTime(t time.Time) *time.Time { return &t }
|
||||
|
||||
func TestCalculateProgress_BasicFields(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user