基于 backend-code-audit 审计报告,修复剩余 P0/P1/P2 共 34 项问题: P0 生产 Bug: - 修复 time.Since(time.Now()) 计时逻辑错误 (P0-03) - generateRandomID 改用 crypto/rand 替代固定索引 (P0-04) - IncrementQuotaUsed 重写为 Ent 原子操作消除 TOCTOU 竞态 (P0-05) 安全加固: - gateway/openai handler 错误响应替换为泛化消息,防止内部信息泄露 (P1-14) - usage_log_repo dateFormat 参数改用白名单映射,防止 SQL 注入 (P1-16) - 默认配置安全加固:sslmode=prefer、response_headers=true、mode=release (P1-18/19, P2-15) 性能优化: - gateway handler 循环内 defer 替换为显式 releaseWait 闭包 (P1-02) - group_repo/promo_code_repo Count 前 Clone 查询避免状态污染 (P1-03) - usage_log_repo 四个查询添加 LIMIT 10000 防止 OOM (P1-07) - GetBatchUsageStats 添加时间范围参数,默认最近 30 天 (P1-10) - ip.go CIDR 预编译为包级变量 (P1-11) - BatchUpdateCredentials 重构为先验证后更新 (P1-13) 缓存一致性: - billing_cache 添加 jitteredTTL 防止缓存雪崩 (P2-10) - DeductUserBalance/UpdateSubscriptionUsage 错误传播修复 (P2-12) - UserService.UpdateBalance 成功后异步失效 billingCache (P2-13) 代码质量: - search 截断改为按 rune 处理,支持多字节字符 (P2-01) - TLS Handshake 改为 HandshakeContext 支持 context 取消 (P2-07) - CORS 预检添加 Access-Control-Max-Age: 86400 (P2-16) 测试覆盖: - 新增 user_service_test.go(UpdateBalance 缓存失效 6 个用例) - 新增 batch_update_credentials_test.go(fail-fast + 类型验证 7 个用例) - 新增 response_transformer_test.go、ip_test.go、usage_log_repo_unit_test.go、search_truncate_test.go - 集成测试:IncrementQuotaUsed 并发测试、billing_cache 错误传播测试 - config_test.go 补充 server.mode/sslmode 默认值断言 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
368 lines
12 KiB
Go
368 lines
12 KiB
Go
//go:build integration
|
||
|
||
package repository
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
"github.com/redis/go-redis/v9"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/stretchr/testify/suite"
|
||
)
|
||
|
||
type BillingCacheSuite struct {
|
||
IntegrationRedisSuite
|
||
}
|
||
|
||
func (s *BillingCacheSuite) TestUserBalance() {
|
||
tests := []struct {
|
||
name string
|
||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||
}{
|
||
{
|
||
name: "missing_key_returns_redis_nil",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
_, err := cache.GetUserBalance(ctx, 1)
|
||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
|
||
},
|
||
},
|
||
{
|
||
name: "deduct_on_nonexistent_is_noop",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(1)
|
||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||
|
||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
|
||
|
||
_, err := rdb.Get(ctx, balanceKey).Result()
|
||
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
|
||
},
|
||
},
|
||
{
|
||
name: "set_and_get_with_ttl",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(2)
|
||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||
|
||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||
|
||
got, err := cache.GetUserBalance(ctx, userID)
|
||
require.NoError(s.T(), err, "GetUserBalance")
|
||
require.Equal(s.T(), 10.5, got, "balance mismatch")
|
||
|
||
ttl, err := rdb.TTL(ctx, balanceKey).Result()
|
||
require.NoError(s.T(), err, "TTL")
|
||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||
},
|
||
},
|
||
{
|
||
name: "deduct_reduces_balance",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(3)
|
||
|
||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
|
||
|
||
got, err := cache.GetUserBalance(ctx, userID)
|
||
require.NoError(s.T(), err, "GetUserBalance after deduct")
|
||
require.Equal(s.T(), 8.25, got, "deduct mismatch")
|
||
},
|
||
},
|
||
{
|
||
name: "invalidate_removes_key",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(100)
|
||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||
|
||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
|
||
|
||
exists, err := rdb.Exists(ctx, balanceKey).Result()
|
||
require.NoError(s.T(), err, "Exists")
|
||
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
|
||
|
||
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
|
||
|
||
exists, err = rdb.Exists(ctx, balanceKey).Result()
|
||
require.NoError(s.T(), err, "Exists after invalidate")
|
||
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
|
||
|
||
_, err = cache.GetUserBalance(ctx, userID)
|
||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||
},
|
||
},
|
||
{
|
||
name: "deduct_refreshes_ttl",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(103)
|
||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||
|
||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
|
||
|
||
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
|
||
require.NoError(s.T(), err, "TTL before deduct")
|
||
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
|
||
|
||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
|
||
|
||
balance, err := cache.GetUserBalance(ctx, userID)
|
||
require.NoError(s.T(), err, "GetUserBalance")
|
||
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
|
||
|
||
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
|
||
require.NoError(s.T(), err, "TTL after deduct")
|
||
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
s.Run(tt.name, func() {
|
||
rdb := testRedis(s.T())
|
||
cache := NewBillingCache(rdb)
|
||
ctx := context.Background()
|
||
|
||
tt.fn(ctx, rdb, cache)
|
||
})
|
||
}
|
||
}
|
||
|
||
func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||
tests := []struct {
|
||
name string
|
||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||
}{
|
||
{
|
||
name: "missing_key_returns_redis_nil",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(10)
|
||
groupID := int64(20)
|
||
|
||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
|
||
},
|
||
},
|
||
{
|
||
name: "update_usage_on_nonexistent_is_noop",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(11)
|
||
groupID := int64(21)
|
||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||
|
||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
|
||
|
||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||
require.NoError(s.T(), err, "Exists")
|
||
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
|
||
},
|
||
},
|
||
{
|
||
name: "set_and_get_with_ttl",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(12)
|
||
groupID := int64(22)
|
||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||
|
||
data := &service.SubscriptionCacheData{
|
||
Status: "active",
|
||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||
DailyUsage: 1.0,
|
||
WeeklyUsage: 2.0,
|
||
MonthlyUsage: 3.0,
|
||
Version: 7,
|
||
}
|
||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||
|
||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||
require.NoError(s.T(), err, "GetSubscriptionCache")
|
||
require.Equal(s.T(), "active", gotSub.Status)
|
||
require.Equal(s.T(), int64(7), gotSub.Version)
|
||
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
|
||
|
||
ttl, err := rdb.TTL(ctx, subKey).Result()
|
||
require.NoError(s.T(), err, "TTL subKey")
|
||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||
},
|
||
},
|
||
{
|
||
name: "update_usage_increments_all_fields",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(13)
|
||
groupID := int64(23)
|
||
|
||
data := &service.SubscriptionCacheData{
|
||
Status: "active",
|
||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||
DailyUsage: 1.0,
|
||
WeeklyUsage: 2.0,
|
||
MonthlyUsage: 3.0,
|
||
Version: 1,
|
||
}
|
||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||
|
||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
|
||
|
||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||
require.NoError(s.T(), err, "GetSubscriptionCache after update")
|
||
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
|
||
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
|
||
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
|
||
},
|
||
},
|
||
{
|
||
name: "invalidate_removes_key",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(101)
|
||
groupID := int64(10)
|
||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||
|
||
data := &service.SubscriptionCacheData{
|
||
Status: "active",
|
||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||
DailyUsage: 1.0,
|
||
WeeklyUsage: 2.0,
|
||
MonthlyUsage: 3.0,
|
||
Version: 1,
|
||
}
|
||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||
|
||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||
require.NoError(s.T(), err, "Exists")
|
||
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
|
||
|
||
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
|
||
|
||
exists, err = rdb.Exists(ctx, subKey).Result()
|
||
require.NoError(s.T(), err, "Exists after invalidate")
|
||
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
|
||
|
||
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
|
||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||
},
|
||
},
|
||
{
|
||
name: "missing_status_returns_parsing_error",
|
||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||
userID := int64(102)
|
||
groupID := int64(11)
|
||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||
|
||
fields := map[string]any{
|
||
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
|
||
"daily_usage": 1.0,
|
||
"weekly_usage": 2.0,
|
||
"monthly_usage": 3.0,
|
||
"version": 1,
|
||
}
|
||
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
|
||
|
||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||
require.Error(s.T(), err, "expected error for missing status field")
|
||
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
|
||
require.Equal(s.T(), "invalid cache: missing status", err.Error())
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
s.Run(tt.name, func() {
|
||
rdb := testRedis(s.T())
|
||
cache := NewBillingCache(rdb)
|
||
ctx := context.Background()
|
||
|
||
tt.fn(ctx, rdb, cache)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
|
||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
|
||
tests := []struct {
|
||
name string
|
||
fn func(ctx context.Context, cache service.BillingCache)
|
||
expectErr bool
|
||
}{
|
||
{
|
||
name: "key_not_exists_returns_nil",
|
||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
|
||
err := cache.DeductUserBalance(ctx, 99999, 1.0)
|
||
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
|
||
},
|
||
},
|
||
{
|
||
name: "existing_key_deducts_successfully",
|
||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
|
||
err := cache.DeductUserBalance(ctx, 200, 10.0)
|
||
require.NoError(s.T(), err, "DeductUserBalance should succeed")
|
||
|
||
bal, err := cache.GetUserBalance(ctx, 200)
|
||
require.NoError(s.T(), err)
|
||
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
|
||
},
|
||
},
|
||
{
|
||
name: "cancelled_context_propagates_error",
|
||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
|
||
|
||
cancelCtx, cancel := context.WithCancel(ctx)
|
||
cancel() // 立即取消
|
||
|
||
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
|
||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
s.Run(tt.name, func() {
|
||
rdb := testRedis(s.T())
|
||
cache := NewBillingCache(rdb)
|
||
ctx := context.Background()
|
||
tt.fn(ctx, cache)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
|
||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
|
||
s.Run("key_not_exists_returns_nil", func() {
|
||
rdb := testRedis(s.T())
|
||
cache := NewBillingCache(rdb)
|
||
ctx := context.Background()
|
||
|
||
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
|
||
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
|
||
})
|
||
|
||
s.Run("cancelled_context_propagates_error", func() {
|
||
rdb := testRedis(s.T())
|
||
cache := NewBillingCache(rdb)
|
||
ctx := context.Background()
|
||
|
||
data := &service.SubscriptionCacheData{
|
||
Status: "active",
|
||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||
Version: 1,
|
||
}
|
||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
|
||
|
||
cancelCtx, cancel := context.WithCancel(ctx)
|
||
cancel()
|
||
|
||
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
|
||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||
})
|
||
}
|
||
|
||
func TestBillingCacheSuite(t *testing.T) {
|
||
suite.Run(t, new(BillingCacheSuite))
|
||
}
|