diff --git a/backend/.golangci.yml b/backend/.golangci.yml index 6d078f1f..8469e2cb 100644 --- a/backend/.golangci.yml +++ b/backend/.golangci.yml @@ -23,6 +23,8 @@ linters: desc: "service must not import repository" - pkg: gorm.io/gorm desc: "service must not import gorm" + - pkg: github.com/redis/go-redis/v9 + desc: "service must not import redis" handler-no-repository: list-mode: original files: @@ -30,6 +32,10 @@ linters: deny: - pkg: github.com/Wei-Shaw/sub2api/internal/repository desc: "handler must not import repository" + - pkg: gorm.io/gorm + desc: "handler must not import gorm" + - pkg: github.com/redis/go-redis/v9 + desc: "handler must not import redis" errcheck: # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. # Such cases aren't reported by default. diff --git a/backend/internal/repository/api_key_cache.go b/backend/internal/repository/api_key_cache.go index a33382ec..84565b47 100644 --- a/backend/internal/repository/api_key_cache.go +++ b/backend/internal/repository/api_key_cache.go @@ -2,6 +2,7 @@ package repository import ( "context" + "errors" "fmt" "time" @@ -14,6 +15,11 @@ const ( apiKeyRateLimitDuration = 24 * time.Hour ) +// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting. +func apiKeyRateLimitKey(userID int64) string { + return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) +} + type apiKeyCache struct { rdb *redis.Client } @@ -23,12 +29,16 @@ func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache { } func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { - key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) - return c.rdb.Get(ctx, key).Int() + key := apiKeyRateLimitKey(userID) + count, err := c.rdb.Get(ctx, key).Int() + if errors.Is(err, redis.Nil) { + return 0, nil + } + return count, err } func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { - key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) + key := apiKeyRateLimitKey(userID) pipe := c.rdb.Pipeline() pipe.Incr(ctx, key) pipe.Expire(ctx, key, apiKeyRateLimitDuration) @@ -37,7 +47,7 @@ func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID in } func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { - key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) + key := apiKeyRateLimitKey(userID) return c.rdb.Del(ctx, key).Err() } diff --git a/backend/internal/repository/api_key_cache_integration_test.go b/backend/internal/repository/api_key_cache_integration_test.go index 6fcd0dfd..e9394917 100644 --- a/backend/internal/repository/api_key_cache_integration_test.go +++ b/backend/internal/repository/api_key_cache_integration_test.go @@ -23,13 +23,14 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() { fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) }{ { - name: "missing_key_returns_redis_nil", + name: "missing_key_returns_zero_nil", fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) { userID := int64(1) - _, err := cache.GetCreateAttemptCount(ctx, userID) + count, err := cache.GetCreateAttemptCount(ctx, userID) - require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key") + require.NoError(s.T(), err, "expected nil error for missing key") + require.Equal(s.T(), 0, count, "expected zero count for missing key") }, }, { @@ -58,8 +59,9 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() { require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID)) require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount") - _, err := cache.GetCreateAttemptCount(ctx, userID) - require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete") + count, err := cache.GetCreateAttemptCount(ctx, userID) + require.NoError(s.T(), err, "expected nil error after delete") + require.Equal(s.T(), 0, count, "expected zero count after delete") }, }, } diff --git a/backend/internal/repository/api_key_cache_test.go b/backend/internal/repository/api_key_cache_test.go new file mode 100644 index 00000000..7ad84ba2 --- /dev/null +++ b/backend/internal/repository/api_key_cache_test.go @@ -0,0 +1,46 @@ +//go:build unit + +package repository + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestApiKeyRateLimitKey(t *testing.T) { + tests := []struct { + name string + userID int64 + expected string + }{ + { + name: "normal_user_id", + userID: 123, + expected: "apikey:ratelimit:123", + }, + { + name: "zero_user_id", + userID: 0, + expected: "apikey:ratelimit:0", + }, + { + name: "negative_user_id", + userID: -1, + expected: "apikey:ratelimit:-1", + }, + { + name: "max_int64", + userID: math.MaxInt64, + expected: "apikey:ratelimit:9223372036854775807", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := apiKeyRateLimitKey(tc.userID) + require.Equal(t, tc.expected, got) + }) + } +} diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index 26d789d1..ac5803a1 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -18,6 +18,16 @@ const ( billingCacheTTL = 5 * time.Minute ) +// billingBalanceKey generates the Redis key for user balance cache. +func billingBalanceKey(userID int64) string { + return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) +} + +// billingSubKey generates the Redis key for subscription cache. +func billingSubKey(userID, groupID int64) string { + return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) +} + const ( subFieldStatus = "status" subFieldExpiresAt = "expires_at" @@ -62,7 +72,7 @@ func NewBillingCache(rdb *redis.Client) service.BillingCache { } func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) { - key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + key := billingBalanceKey(userID) val, err := c.rdb.Get(ctx, key).Result() if err != nil { return 0, err @@ -71,12 +81,12 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6 } func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { - key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + key := billingBalanceKey(userID) return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err() } func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { - key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + key := billingBalanceKey(userID) _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) @@ -85,12 +95,12 @@ func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amou } func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error { - key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + key := billingBalanceKey(userID) return c.rdb.Del(ctx, key).Err() } func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) { - key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + key := billingSubKey(userID, groupID) result, err := c.rdb.HGetAll(ctx, key).Result() if err != nil { return nil, err @@ -140,7 +150,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID return nil } - key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + key := billingSubKey(userID, groupID) fields := map[string]any{ subFieldStatus: data.Status, @@ -159,7 +169,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID } func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { - key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + key := billingSubKey(userID, groupID) _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) @@ -168,6 +178,6 @@ func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, grou } func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { - key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + key := billingSubKey(userID, groupID) return c.rdb.Del(ctx, key).Err() } diff --git a/backend/internal/repository/billing_cache_test.go b/backend/internal/repository/billing_cache_test.go new file mode 100644 index 00000000..7d3fd19d --- /dev/null +++ b/backend/internal/repository/billing_cache_test.go @@ -0,0 +1,87 @@ +//go:build unit + +package repository + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBillingBalanceKey(t *testing.T) { + tests := []struct { + name string + userID int64 + expected string + }{ + { + name: "normal_user_id", + userID: 123, + expected: "billing:balance:123", + }, + { + name: "zero_user_id", + userID: 0, + expected: "billing:balance:0", + }, + { + name: "negative_user_id", + userID: -1, + expected: "billing:balance:-1", + }, + { + name: "max_int64", + userID: math.MaxInt64, + expected: "billing:balance:9223372036854775807", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := billingBalanceKey(tc.userID) + require.Equal(t, tc.expected, got) + }) + } +} + +func TestBillingSubKey(t *testing.T) { + tests := []struct { + name string + userID int64 + groupID int64 + expected string + }{ + { + name: "normal_ids", + userID: 123, + groupID: 456, + expected: "billing:sub:123:456", + }, + { + name: "zero_ids", + userID: 0, + groupID: 0, + expected: "billing:sub:0:0", + }, + { + name: "negative_ids", + userID: -1, + groupID: -2, + expected: "billing:sub:-1:-2", + }, + { + name: "max_int64_ids", + userID: math.MaxInt64, + groupID: math.MaxInt64, + expected: "billing:sub:9223372036854775807:9223372036854775807", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := billingSubKey(tc.userID, tc.groupID) + require.Equal(t, tc.expected, got) + }) + } +} diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index d6cb5c01..e00e35dd 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -11,6 +11,11 @@ import ( const verifyCodeKeyPrefix = "verify_code:" +// verifyCodeKey generates the Redis key for email verification code. +func verifyCodeKey(email string) string { + return verifyCodeKeyPrefix + email +} + type emailCache struct { rdb *redis.Client } @@ -20,7 +25,7 @@ func NewEmailCache(rdb *redis.Client) service.EmailCache { } func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) { - key := verifyCodeKeyPrefix + email + key := verifyCodeKey(email) val, err := c.rdb.Get(ctx, key).Result() if err != nil { return nil, err @@ -33,7 +38,7 @@ func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*se } func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error { - key := verifyCodeKeyPrefix + email + key := verifyCodeKey(email) val, err := json.Marshal(data) if err != nil { return err @@ -42,6 +47,6 @@ func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data } func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error { - key := verifyCodeKeyPrefix + email + key := verifyCodeKey(email) return c.rdb.Del(ctx, key).Err() } diff --git a/backend/internal/repository/email_cache_test.go b/backend/internal/repository/email_cache_test.go new file mode 100644 index 00000000..1c498938 --- /dev/null +++ b/backend/internal/repository/email_cache_test.go @@ -0,0 +1,45 @@ +//go:build unit + +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestVerifyCodeKey(t *testing.T) { + tests := []struct { + name string + email string + expected string + }{ + { + name: "normal_email", + email: "user@example.com", + expected: "verify_code:user@example.com", + }, + { + name: "empty_email", + email: "", + expected: "verify_code:", + }, + { + name: "email_with_plus", + email: "user+tag@example.com", + expected: "verify_code:user+tag@example.com", + }, + { + name: "email_with_special_chars", + email: "user.name+tag@sub.domain.com", + expected: "verify_code:user.name+tag@sub.domain.com", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := verifyCodeKey(tc.email) + require.Equal(t, tc.expected, got) + }) + } +} diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go index 9c776d9c..d28477b7 100644 --- a/backend/internal/repository/identity_cache.go +++ b/backend/internal/repository/identity_cache.go @@ -15,6 +15,11 @@ const ( fingerprintTTL = 24 * time.Hour ) +// fingerprintKey generates the Redis key for account fingerprint cache. +func fingerprintKey(accountID int64) string { + return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) +} + type identityCache struct { rdb *redis.Client } @@ -24,7 +29,7 @@ func NewIdentityCache(rdb *redis.Client) service.IdentityCache { } func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) { - key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) + key := fingerprintKey(accountID) val, err := c.rdb.Get(ctx, key).Result() if err != nil { return nil, err @@ -37,7 +42,7 @@ func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*s } func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error { - key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) + key := fingerprintKey(accountID) val, err := json.Marshal(fp) if err != nil { return err diff --git a/backend/internal/repository/identity_cache_test.go b/backend/internal/repository/identity_cache_test.go new file mode 100644 index 00000000..05921b12 --- /dev/null +++ b/backend/internal/repository/identity_cache_test.go @@ -0,0 +1,46 @@ +//go:build unit + +package repository + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFingerprintKey(t *testing.T) { + tests := []struct { + name string + accountID int64 + expected string + }{ + { + name: "normal_account_id", + accountID: 123, + expected: "fingerprint:123", + }, + { + name: "zero_account_id", + accountID: 0, + expected: "fingerprint:0", + }, + { + name: "negative_account_id", + accountID: -1, + expected: "fingerprint:-1", + }, + { + name: "max_int64", + accountID: math.MaxInt64, + expected: "fingerprint:9223372036854775807", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := fingerprintKey(tc.accountID) + require.Equal(t, tc.expected, got) + }) + } +} diff --git a/backend/internal/repository/redeem_cache.go b/backend/internal/repository/redeem_cache.go index e0330b58..831aaf57 100644 --- a/backend/internal/repository/redeem_cache.go +++ b/backend/internal/repository/redeem_cache.go @@ -15,6 +15,16 @@ const ( redeemRateLimitDuration = 24 * time.Hour ) +// redeemRateLimitKey generates the Redis key for redeem attempt rate limiting. +func redeemRateLimitKey(userID int64) string { + return fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) +} + +// redeemLockKey generates the Redis key for redeem code locking. +func redeemLockKey(code string) string { + return redeemLockKeyPrefix + code +} + type redeemCache struct { rdb *redis.Client } @@ -24,12 +34,16 @@ func NewRedeemCache(rdb *redis.Client) service.RedeemCache { } func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) { - key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) - return c.rdb.Get(ctx, key).Int() + key := redeemRateLimitKey(userID) + count, err := c.rdb.Get(ctx, key).Int() + if err == redis.Nil { + return 0, nil + } + return count, err } func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error { - key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) + key := redeemRateLimitKey(userID) pipe := c.rdb.Pipeline() pipe.Incr(ctx, key) pipe.Expire(ctx, key, redeemRateLimitDuration) @@ -38,11 +52,11 @@ func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID in } func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) { - key := redeemLockKeyPrefix + code + key := redeemLockKey(code) return c.rdb.SetNX(ctx, key, 1, ttl).Result() } func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error { - key := redeemLockKeyPrefix + code + key := redeemLockKey(code) return c.rdb.Del(ctx, key).Err() } diff --git a/backend/internal/repository/redeem_cache_integration_test.go b/backend/internal/repository/redeem_cache_integration_test.go index a7aa05d9..6398a801 100644 --- a/backend/internal/repository/redeem_cache_integration_test.go +++ b/backend/internal/repository/redeem_cache_integration_test.go @@ -3,12 +3,10 @@ package repository import ( - "errors" "fmt" "testing" "time" - "github.com/redis/go-redis/v9" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -25,9 +23,9 @@ func (s *RedeemCacheSuite) SetupTest() { func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() { missingUserID := int64(99999) - _, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID) - require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key") - require.True(s.T(), errors.Is(err, redis.Nil)) + count, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID) + require.NoError(s.T(), err, "expected nil error for missing rate-limit key") + require.Equal(s.T(), 0, count, "expected zero count for missing key") } func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() { diff --git a/backend/internal/repository/redeem_cache_test.go b/backend/internal/repository/redeem_cache_test.go new file mode 100644 index 00000000..9b547b74 --- /dev/null +++ b/backend/internal/repository/redeem_cache_test.go @@ -0,0 +1,77 @@ +//go:build unit + +package repository + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRedeemRateLimitKey(t *testing.T) { + tests := []struct { + name string + userID int64 + expected string + }{ + { + name: "normal_user_id", + userID: 123, + expected: "redeem:ratelimit:123", + }, + { + name: "zero_user_id", + userID: 0, + expected: "redeem:ratelimit:0", + }, + { + name: "negative_user_id", + userID: -1, + expected: "redeem:ratelimit:-1", + }, + { + name: "max_int64", + userID: math.MaxInt64, + expected: "redeem:ratelimit:9223372036854775807", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := redeemRateLimitKey(tc.userID) + require.Equal(t, tc.expected, got) + }) + } +} + +func TestRedeemLockKey(t *testing.T) { + tests := []struct { + name string + code string + expected string + }{ + { + name: "normal_code", + code: "ABC123", + expected: "redeem:lock:ABC123", + }, + { + name: "empty_code", + code: "", + expected: "redeem:lock:", + }, + { + name: "code_with_special_chars", + code: "CODE-2024:test", + expected: "redeem:lock:CODE-2024:test", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := redeemLockKey(tc.code) + require.Equal(t, tc.expected, got) + }) + } +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index ac236175..4ab50fb5 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -4,7 +4,6 @@ import ( "context" "crypto/rand" "encoding/hex" - "errors" "fmt" "time" @@ -12,7 +11,6 @@ import ( infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" - "github.com/redis/go-redis/v9" ) var ( @@ -143,7 +141,7 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) } count, err := s.cache.GetCreateAttemptCount(ctx, userID) - if err != nil && !errors.Is(err, redis.Nil) { + if err != nil { // Redis 出错时不阻止用户操作 return nil } diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 144f2c50..c587d212 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -11,7 +11,6 @@ import ( infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "github.com/redis/go-redis/v9" ) var ( @@ -163,7 +162,7 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) } count, err := s.cache.GetRedeemAttemptCount(ctx, userID) - if err != nil && !errors.Is(err, redis.Nil) { + if err != nil { // Redis 出错时不阻止用户操作 return nil }