根据 Codex 代码审查报告,修复所有 P0 和 P1 优先级问题。 ## P0 紧急修复 ### 1. 修复集成测试编译错误 - 更新 group_repo_integration_test.go 中所有 ListWithFilters 调用 - 添加缺失的 search 参数(传入空字符串) - 修复 4 处旧签名调用,避免 CI 编译失败 ### 2. 添加统一的 search 参数输入验证 为所有 admin handler 添加一致的输入验证逻辑: - group_handler.go: 添加 TrimSpace + 长度限制 - proxy_handler.go: 添加 TrimSpace + 长度限制 - redeem_handler.go: 添加 TrimSpace + 长度限制 - user_handler.go: 添加 TrimSpace + 长度限制 验证规则: - TrimSpace() 去除首尾空格 - 最大长度 100 字符(防止 DoS 攻击) - 超长输入自动截断 ## P1 改进 ### 3. 补充 search 功能的单元测试 新增 admin_service_group_test.go 中的测试: - TestAdminService_ListGroups_WithSearch - search 参数正常传递到 repository 层 - search 为空字符串时的行为 - search 与其他过滤条件组合使用 新增 admin_service_search_test.go 文件: - 为其他 admin API 添加 search 测试覆盖 - 统一的测试模式和断言 ### 4. 补充 search 功能的集成测试 新增 group_repo_integration_test.go 测试场景: - TestListWithFilters_Search - 搜索 name 字段匹配 - 搜索 description 字段匹配 - 搜索不存在内容(返回空) - 大小写不敏感测试 - 特殊字符转义测试(%、_) - 与其他过滤条件组合 ## 测试结果 - ✅ 编译检查通过 - ✅ 单元测试全部通过 (3/3) - ✅ 集成测试编译通过 - ✅ 所有 service 测试通过 ## 影响范围 修改文件: 8 个 代码变更: +234 行 / -8 行 ## 相关 Issue 解决代码审查中的安全性和稳定性问题: - 防止 DoS 攻击(超长搜索字符串) - 修复测试编译错误(CI 阻塞问题) - 提升测试覆盖率
293 lines
9.5 KiB
Go
293 lines
9.5 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"testing"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
|
||
type groupRepoStubForAdmin struct {
|
||
created *Group // 记录 Create 调用的参数
|
||
updated *Group // 记录 Update 调用的参数
|
||
getByID *Group // GetByID 返回值
|
||
getErr error // GetByID 返回的错误
|
||
|
||
listWithFiltersCalls int
|
||
listWithFiltersParams pagination.PaginationParams
|
||
listWithFiltersPlatform string
|
||
listWithFiltersStatus string
|
||
listWithFiltersSearch string
|
||
listWithFiltersIsExclusive *bool
|
||
listWithFiltersGroups []Group
|
||
listWithFiltersResult *pagination.PaginationResult
|
||
listWithFiltersErr error
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
||
s.created = g
|
||
return nil
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) Update(_ context.Context, g *Group) error {
|
||
s.updated = g
|
||
return nil
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, error) {
|
||
if s.getErr != nil {
|
||
return nil, s.getErr
|
||
}
|
||
return s.getByID, nil
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
|
||
panic("unexpected Delete call")
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||
panic("unexpected DeleteCascade call")
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||
panic("unexpected List call")
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||
s.listWithFiltersCalls++
|
||
s.listWithFiltersParams = params
|
||
s.listWithFiltersPlatform = platform
|
||
s.listWithFiltersStatus = status
|
||
s.listWithFiltersSearch = search
|
||
s.listWithFiltersIsExclusive = isExclusive
|
||
|
||
if s.listWithFiltersErr != nil {
|
||
return nil, nil, s.listWithFiltersErr
|
||
}
|
||
|
||
result := s.listWithFiltersResult
|
||
if result == nil {
|
||
result = &pagination.PaginationResult{
|
||
Total: int64(len(s.listWithFiltersGroups)),
|
||
Page: params.Page,
|
||
PageSize: params.PageSize,
|
||
}
|
||
}
|
||
|
||
return s.listWithFiltersGroups, result, nil
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
||
panic("unexpected ListActive call")
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||
panic("unexpected ListActiveByPlatform call")
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||
panic("unexpected ExistsByName call")
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||
panic("unexpected GetAccountCount call")
|
||
}
|
||
|
||
func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||
}
|
||
|
||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||
repo := &groupRepoStubForAdmin{}
|
||
svc := &adminServiceImpl{groupRepo: repo}
|
||
|
||
price1K := 0.10
|
||
price2K := 0.15
|
||
price4K := 0.30
|
||
|
||
input := &CreateGroupInput{
|
||
Name: "test-group",
|
||
Description: "Test group",
|
||
Platform: PlatformAntigravity,
|
||
RateMultiplier: 1.0,
|
||
ImagePrice1K: &price1K,
|
||
ImagePrice2K: &price2K,
|
||
ImagePrice4K: &price4K,
|
||
}
|
||
|
||
group, err := svc.CreateGroup(context.Background(), input)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, group)
|
||
|
||
// 验证 repo 收到了正确的字段
|
||
require.NotNil(t, repo.created)
|
||
require.NotNil(t, repo.created.ImagePrice1K)
|
||
require.NotNil(t, repo.created.ImagePrice2K)
|
||
require.NotNil(t, repo.created.ImagePrice4K)
|
||
require.InDelta(t, 0.10, *repo.created.ImagePrice1K, 0.0001)
|
||
require.InDelta(t, 0.15, *repo.created.ImagePrice2K, 0.0001)
|
||
require.InDelta(t, 0.30, *repo.created.ImagePrice4K, 0.0001)
|
||
}
|
||
|
||
// TestAdminService_CreateGroup_NilImagePricing 测试 ImagePrice 为 nil 时正常创建
|
||
func TestAdminService_CreateGroup_NilImagePricing(t *testing.T) {
|
||
repo := &groupRepoStubForAdmin{}
|
||
svc := &adminServiceImpl{groupRepo: repo}
|
||
|
||
input := &CreateGroupInput{
|
||
Name: "test-group",
|
||
Description: "Test group",
|
||
Platform: PlatformAntigravity,
|
||
RateMultiplier: 1.0,
|
||
// ImagePrice 字段全部为 nil
|
||
}
|
||
|
||
group, err := svc.CreateGroup(context.Background(), input)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, group)
|
||
|
||
// 验证 ImagePrice 字段为 nil
|
||
require.NotNil(t, repo.created)
|
||
require.Nil(t, repo.created.ImagePrice1K)
|
||
require.Nil(t, repo.created.ImagePrice2K)
|
||
require.Nil(t, repo.created.ImagePrice4K)
|
||
}
|
||
|
||
// TestAdminService_UpdateGroup_WithImagePricing 测试更新分组时 ImagePrice 字段正确更新
|
||
func TestAdminService_UpdateGroup_WithImagePricing(t *testing.T) {
|
||
existingGroup := &Group{
|
||
ID: 1,
|
||
Name: "existing-group",
|
||
Platform: PlatformAntigravity,
|
||
Status: StatusActive,
|
||
}
|
||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||
svc := &adminServiceImpl{groupRepo: repo}
|
||
|
||
price1K := 0.12
|
||
price2K := 0.18
|
||
price4K := 0.36
|
||
|
||
input := &UpdateGroupInput{
|
||
ImagePrice1K: &price1K,
|
||
ImagePrice2K: &price2K,
|
||
ImagePrice4K: &price4K,
|
||
}
|
||
|
||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, group)
|
||
|
||
// 验证 repo 收到了更新后的字段
|
||
require.NotNil(t, repo.updated)
|
||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||
require.NotNil(t, repo.updated.ImagePrice4K)
|
||
require.InDelta(t, 0.12, *repo.updated.ImagePrice1K, 0.0001)
|
||
require.InDelta(t, 0.18, *repo.updated.ImagePrice2K, 0.0001)
|
||
require.InDelta(t, 0.36, *repo.updated.ImagePrice4K, 0.0001)
|
||
}
|
||
|
||
// TestAdminService_UpdateGroup_PartialImagePricing 测试仅更新部分 ImagePrice 字段
|
||
func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||
oldPrice2K := 0.15
|
||
existingGroup := &Group{
|
||
ID: 1,
|
||
Name: "existing-group",
|
||
Platform: PlatformAntigravity,
|
||
Status: StatusActive,
|
||
ImagePrice2K: &oldPrice2K, // 已有 2K 价格
|
||
}
|
||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||
svc := &adminServiceImpl{groupRepo: repo}
|
||
|
||
// 只更新 1K 价格
|
||
price1K := 0.10
|
||
input := &UpdateGroupInput{
|
||
ImagePrice1K: &price1K,
|
||
// ImagePrice2K 和 ImagePrice4K 为 nil,不更新
|
||
}
|
||
|
||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, group)
|
||
|
||
// 验证:1K 被更新,2K 保持原值,4K 仍为 nil
|
||
require.NotNil(t, repo.updated)
|
||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||
require.InDelta(t, 0.10, *repo.updated.ImagePrice1K, 0.0001)
|
||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||
require.Nil(t, repo.updated.ImagePrice4K)
|
||
}
|
||
|
||
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||
// 测试:
|
||
// 1. search 参数正常传递到 repository 层
|
||
// 2. search 为空字符串时的行为
|
||
// 3. search 与其他过滤条件组合使用
|
||
|
||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||
repo := &groupRepoStubForAdmin{
|
||
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
|
||
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
|
||
}
|
||
svc := &adminServiceImpl{groupRepo: repo}
|
||
|
||
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
|
||
require.NoError(t, err)
|
||
require.Equal(t, int64(1), total)
|
||
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
|
||
|
||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||
require.Equal(t, "alpha", repo.listWithFiltersSearch)
|
||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||
})
|
||
|
||
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
|
||
repo := &groupRepoStubForAdmin{
|
||
listWithFiltersGroups: []Group{},
|
||
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
|
||
}
|
||
svc := &adminServiceImpl{groupRepo: repo}
|
||
|
||
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
|
||
require.NoError(t, err)
|
||
require.Empty(t, groups)
|
||
require.Equal(t, int64(0), total)
|
||
|
||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
|
||
require.Equal(t, "", repo.listWithFiltersSearch)
|
||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||
})
|
||
|
||
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
|
||
isExclusive := true
|
||
repo := &groupRepoStubForAdmin{
|
||
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
|
||
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
|
||
}
|
||
svc := &adminServiceImpl{groupRepo: repo}
|
||
|
||
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
|
||
require.NoError(t, err)
|
||
require.Equal(t, int64(42), total)
|
||
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
|
||
|
||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
|
||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||
require.Equal(t, "beta", repo.listWithFiltersSearch)
|
||
require.NotNil(t, repo.listWithFiltersIsExclusive)
|
||
require.True(t, *repo.listWithFiltersIsExclusive)
|
||
})
|
||
}
|