refactor: 封装 Redis key 生成函数

This commit is contained in:
Forest
2025-12-26 16:33:20 +08:00
parent e5a77853b0
commit 06d5876b02
15 changed files with 385 additions and 37 deletions

View File

@@ -23,6 +23,8 @@ linters:
desc: "service must not import repository"
- pkg: gorm.io/gorm
desc: "service must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "service must not import redis"
handler-no-repository:
list-mode: original
files:
@@ -30,6 +32,10 @@ linters:
deny:
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "handler must not import repository"
- pkg: gorm.io/gorm
desc: "handler must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "handler must not import redis"
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"errors"
"fmt"
"time"
@@ -14,6 +15,11 @@ const (
apiKeyRateLimitDuration = 24 * time.Hour
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
type apiKeyCache struct {
rdb *redis.Client
}
@@ -23,12 +29,16 @@ func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
}
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
key := apiKeyRateLimitKey(userID)
count, err := c.rdb.Get(ctx, key).Int()
if errors.Is(err, redis.Nil) {
return 0, nil
}
return count, err
}
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
key := apiKeyRateLimitKey(userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
@@ -37,7 +47,7 @@ func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID in
}
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
key := apiKeyRateLimitKey(userID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -23,13 +23,14 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "missing_key_returns_redis_nil",
name: "missing_key_returns_zero_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
_, err := cache.GetCreateAttemptCount(ctx, userID)
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key")
require.NoError(s.T(), err, "expected nil error for missing key")
require.Equal(s.T(), 0, count, "expected zero count for missing key")
},
},
{
@@ -58,8 +59,9 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
_, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete")
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "expected nil error after delete")
require.Equal(s.T(), 0, count, "expected zero count after delete")
},
},
}

View File

@@ -0,0 +1,46 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestApiKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "apikey:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "apikey:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "apikey:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "apikey:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := apiKeyRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -18,6 +18,16 @@ const (
billingCacheTTL = 5 * time.Minute
)
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
}
// billingSubKey generates the Redis key for subscription cache.
func billingSubKey(userID, groupID int64) string {
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
}
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
@@ -62,7 +72,7 @@ func NewBillingCache(rdb *redis.Client) service.BillingCache {
}
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
key := billingBalanceKey(userID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return 0, err
@@ -71,12 +81,12 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
}
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
@@ -85,12 +95,12 @@ func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amou
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
key := billingBalanceKey(userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
key := billingSubKey(userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
@@ -140,7 +150,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
key := billingSubKey(userID, groupID)
fields := map[string]any{
subFieldStatus: data.Status,
@@ -159,7 +169,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
@@ -168,6 +178,6 @@ func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, grou
}
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,87 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBillingBalanceKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "billing:balance:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "billing:balance:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "billing:balance:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "billing:balance:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingBalanceKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestBillingSubKey(t *testing.T) {
tests := []struct {
name string
userID int64
groupID int64
expected string
}{
{
name: "normal_ids",
userID: 123,
groupID: 456,
expected: "billing:sub:123:456",
},
{
name: "zero_ids",
userID: 0,
groupID: 0,
expected: "billing:sub:0:0",
},
{
name: "negative_ids",
userID: -1,
groupID: -2,
expected: "billing:sub:-1:-2",
},
{
name: "max_int64_ids",
userID: math.MaxInt64,
groupID: math.MaxInt64,
expected: "billing:sub:9223372036854775807:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingSubKey(tc.userID, tc.groupID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -11,6 +11,11 @@ import (
const verifyCodeKeyPrefix = "verify_code:"
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
type emailCache struct {
rdb *redis.Client
}
@@ -20,7 +25,7 @@ func NewEmailCache(rdb *redis.Client) service.EmailCache {
}
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
key := verifyCodeKeyPrefix + email
key := verifyCodeKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
@@ -33,7 +38,7 @@ func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*se
}
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKeyPrefix + email
key := verifyCodeKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
@@ -42,6 +47,6 @@ func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data
}
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
key := verifyCodeKeyPrefix + email
key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,45 @@
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestVerifyCodeKey(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
name: "normal_email",
email: "user@example.com",
expected: "verify_code:user@example.com",
},
{
name: "empty_email",
email: "",
expected: "verify_code:",
},
{
name: "email_with_plus",
email: "user+tag@example.com",
expected: "verify_code:user+tag@example.com",
},
{
name: "email_with_special_chars",
email: "user.name+tag@sub.domain.com",
expected: "verify_code:user.name+tag@sub.domain.com",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := verifyCodeKey(tc.email)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -15,6 +15,11 @@ const (
fingerprintTTL = 24 * time.Hour
)
// fingerprintKey generates the Redis key for account fingerprint cache.
func fingerprintKey(accountID int64) string {
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
}
type identityCache struct {
rdb *redis.Client
}
@@ -24,7 +29,7 @@ func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
}
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
key := fingerprintKey(accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
@@ -37,7 +42,7 @@ func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*s
}
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
key := fingerprintKey(accountID)
val, err := json.Marshal(fp)
if err != nil {
return err

View File

@@ -0,0 +1,46 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestFingerprintKey(t *testing.T) {
tests := []struct {
name string
accountID int64
expected string
}{
{
name: "normal_account_id",
accountID: 123,
expected: "fingerprint:123",
},
{
name: "zero_account_id",
accountID: 0,
expected: "fingerprint:0",
},
{
name: "negative_account_id",
accountID: -1,
expected: "fingerprint:-1",
},
{
name: "max_int64",
accountID: math.MaxInt64,
expected: "fingerprint:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := fingerprintKey(tc.accountID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -15,6 +15,16 @@ const (
redeemRateLimitDuration = 24 * time.Hour
)
// redeemRateLimitKey generates the Redis key for redeem attempt rate limiting.
func redeemRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
}
// redeemLockKey generates the Redis key for redeem code locking.
func redeemLockKey(code string) string {
return redeemLockKeyPrefix + code
}
type redeemCache struct {
rdb *redis.Client
}
@@ -24,12 +34,16 @@ func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
}
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
key := redeemRateLimitKey(userID)
count, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
return count, err
}
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
key := redeemRateLimitKey(userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
@@ -38,11 +52,11 @@ func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID in
}
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
key := redeemLockKeyPrefix + code
key := redeemLockKey(code)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
key := redeemLockKeyPrefix + code
key := redeemLockKey(code)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -3,12 +3,10 @@
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -25,9 +23,9 @@ func (s *RedeemCacheSuite) SetupTest() {
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
missingUserID := int64(99999)
_, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key")
require.True(s.T(), errors.Is(err, redis.Nil))
count, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
require.NoError(s.T(), err, "expected nil error for missing rate-limit key")
require.Equal(s.T(), 0, count, "expected zero count for missing key")
}
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {

View File

@@ -0,0 +1,77 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestRedeemRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "redeem:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "redeem:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "redeem:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "redeem:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := redeemRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestRedeemLockKey(t *testing.T) {
tests := []struct {
name string
code string
expected string
}{
{
name: "normal_code",
code: "ABC123",
expected: "redeem:lock:ABC123",
},
{
name: "empty_code",
code: "",
expected: "redeem:lock:",
},
{
name: "code_with_special_chars",
code: "CODE-2024:test",
expected: "redeem:lock:CODE-2024:test",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := redeemLockKey(tc.code)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
@@ -12,7 +11,6 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/redis/go-redis/v9"
)
var (
@@ -143,7 +141,7 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
}
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
if err != nil && !errors.Is(err, redis.Nil) {
if err != nil {
// Redis 出错时不阻止用户操作
return nil
}

View File

@@ -11,7 +11,6 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
)
var (
@@ -163,7 +162,7 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64)
}
count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
if err != nil && !errors.Is(err, redis.Nil) {
if err != nil {
// Redis 出错时不阻止用户操作
return nil
}