Files
xinghuoapi/backend/internal/repository/billing_cache.go
shaw b8b5cec35c fix: resolve CI lint errors and test compilation failures for rate limit feature
- Fix errcheck: properly handle rows.Close() error via named return + defer closure
- Fix gofmt: auto-format billing_cache.go, api_key_service.go, billing_cache_service.go
- Add missing rate limit interface methods to 4 test stubs (GetRateLimitData, IncrementRateLimitUsage, ResetRateLimitWindows)
- Fix NewBillingCacheService calls missing the new apiKeyRepo parameter
2026-03-03 15:43:08 +08:00

295 lines
8.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"errors"
"fmt"
"log"
"math/rand/v2"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingRateLimitKeyPrefix = "apikey:rate:"
billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
)
// jitteredTTL 返回带随机抖动的 TTL防止缓存雪崩
func jitteredTTL() time.Duration {
// 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL避免上界预期被打破
if billingCacheJitter <= 0 {
return billingCacheTTL
}
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
return billingCacheTTL - jitter
}
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
}
// billingSubKey generates the Redis key for subscription cache.
func billingSubKey(userID, groupID int64) string {
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
}
// Field names of the subscription cache hash stored under billingSubKey.
// NOTE: updateSubUsageScript hard-codes the three *_usage names below;
// keep the Lua script in sync if any of them change.
const (
	subFieldStatus    = "status"     // required; an entry without it is rejected on read
	subFieldExpiresAt = "expires_at" // written as Unix seconds
	subFieldVersion   = "version"

	subFieldDailyUsage   = "daily_usage"
	subFieldWeeklyUsage  = "weekly_usage"
	subFieldMonthlyUsage = "monthly_usage"
)
// billingRateLimitKey generates the Redis key for API key rate limit cache.
func billingRateLimitKey(keyID int64) string {
return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID)
}
// Field names of the API-key rate-limit cache hash stored under
// billingRateLimitKey.
// NOTE: updateRateLimitUsageScript hard-codes the usage_* names below;
// keep the Lua script in sync if any of them change.
const (
	// Accumulated usage per window (incremented with HINCRBYFLOAT).
	rateLimitFieldUsage5h = "usage_5h"
	rateLimitFieldUsage1d = "usage_1d"
	rateLimitFieldUsage7d = "usage_7d"

	// Integer window markers, one per usage counter.
	rateLimitFieldWindow5h = "window_5h"
	rateLimitFieldWindow1d = "window_1d"
	rateLimitFieldWindow7d = "window_7d"
)
// Lua scripts run server-side by Redis. Each script first checks that its
// cache key exists and returns 0 without writing anything on a cache miss,
// so a miss never creates a partially-populated entry. On success they
// return 1.
//
// NOTE(review): the hash field names inside the scripts are hard-coded and
// must stay in sync with the subField* / rateLimitField* constants.
var (
	// deductBalanceScript subtracts ARGV[1] from the numeric string at KEYS[1]
	// and resets the key's expiry to ARGV[2] seconds.
	deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
	// updateSubUsageScript adds ARGV[1] to the daily, weekly and monthly usage
	// fields of the subscription hash at KEYS[1] and resets its expiry to
	// ARGV[2] seconds.
	updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
	// updateRateLimitUsageScript atomically increments all three rate limit usage counters
	// of the hash at KEYS[1] by ARGV[1] and resets its expiry to ARGV[2] seconds.
	// Returns 0 if the key doesn't exist (cache miss), 1 on success.
	updateRateLimitUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
// billingCache implements service.BillingCache on top of a Redis client.
type billingCache struct {
	rdb *redis.Client // client used for every cache command
}

// NewBillingCache returns a service.BillingCache backed by rdb.
func NewBillingCache(rdb *redis.Client) service.BillingCache {
	cache := new(billingCache)
	cache.rdb = rdb
	return cache
}
// GetUserBalance returns the cached balance of userID.
//
// Any error from the GET command is returned unchanged (go-redis reports a
// missing key as redis.Nil). A stored value that is not a valid float yields
// the strconv parse error.
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
	cached, err := c.rdb.Get(ctx, billingBalanceKey(userID)).Result()
	if err == nil {
		return strconv.ParseFloat(cached, 64)
	}
	return 0, err
}
// SetUserBalance stores balance as userID's cached balance, expiring after a
// jittered TTL (see jitteredTTL).
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
	ttl := jitteredTTL()
	cmd := c.rdb.Set(ctx, billingBalanceKey(userID), balance, ttl)
	return cmd.Err()
}
// DeductUserBalance subtracts amount from userID's cached balance by running
// deductBalanceScript. The script also refreshes the key's expiry to a
// jittered TTL, truncated to whole seconds.
//
// When the balance is not cached, the script is a no-op and nil is returned.
// Redis errors are logged as a warning and returned.
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
	keys := []string{billingBalanceKey(userID)}
	ttlSeconds := int(jitteredTTL().Seconds())
	err := deductBalanceScript.Run(ctx, c.rdb, keys, amount, ttlSeconds).Err()
	if err == nil || errors.Is(err, redis.Nil) {
		return nil
	}
	log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
	return err
}
// InvalidateUserBalance deletes userID's cached balance. Deleting a key that
// does not exist is not an error.
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
	return c.rdb.Del(ctx, billingBalanceKey(userID)).Err()
}
// GetSubscriptionCache loads and decodes the cached subscription for the
// (userID, groupID) pair. A missing entry is reported as redis.Nil.
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
	fields, err := c.rdb.HGetAll(ctx, billingSubKey(userID, groupID)).Result()
	switch {
	case err != nil:
		return nil, err
	case len(fields) == 0:
		// HGETALL on a missing key returns an empty map rather than an
		// error, so map it onto the conventional cache-miss sentinel.
		return nil, redis.Nil
	}
	return c.parseSubscriptionCache(fields)
}
// parseSubscriptionCache decodes the raw subscription hash (as returned by
// HGETALL) into a SubscriptionCacheData.
//
// The status field is mandatory. Every other field is optional: an absent
// field leaves the corresponding value at its zero value. A field that is
// present but cannot be parsed makes the whole entry invalid, and an
// "invalid cache" error is returned. Silently reading such a value as 0 would
// under-report usage, and with it quota consumption, for a corrupted entry.
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
	result := &service.SubscriptionCacheData{Status: data[subFieldStatus]}
	if result.Status == "" {
		return nil, errors.New("invalid cache: missing status")
	}

	if raw, ok := data[subFieldExpiresAt]; ok {
		// SetSubscriptionCache stores the expiry as Unix seconds.
		expiresAt, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("invalid cache: field %s: %w", subFieldExpiresAt, err)
		}
		result.ExpiresAt = time.Unix(expiresAt, 0)
	}

	usageFields := []struct {
		name string
		dst  *float64
	}{
		{subFieldDailyUsage, &result.DailyUsage},
		{subFieldWeeklyUsage, &result.WeeklyUsage},
		{subFieldMonthlyUsage, &result.MonthlyUsage},
	}
	for _, f := range usageFields {
		raw, ok := data[f.name]
		if !ok {
			continue
		}
		v, err := strconv.ParseFloat(raw, 64)
		if err != nil {
			return nil, fmt.Errorf("invalid cache: field %s: %w", f.name, err)
		}
		*f.dst = v
	}

	if raw, ok := data[subFieldVersion]; ok {
		version, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("invalid cache: field %s: %w", subFieldVersion, err)
		}
		result.Version = version
	}
	return result, nil
}
// SetSubscriptionCache writes every field of data into the subscription hash
// for (userID, groupID) and sets a jittered expiry. Both commands are sent in
// a single pipeline. A nil data is ignored.
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
	if data == nil {
		return nil
	}
	key := billingSubKey(userID, groupID)
	_, err := c.rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.HSet(ctx, key, map[string]any{
			subFieldStatus:       data.Status,
			subFieldExpiresAt:    data.ExpiresAt.Unix(),
			subFieldDailyUsage:   data.DailyUsage,
			subFieldWeeklyUsage:  data.WeeklyUsage,
			subFieldMonthlyUsage: data.MonthlyUsage,
			subFieldVersion:      data.Version,
		})
		pipe.Expire(ctx, key, jitteredTTL())
		return nil
	})
	return err
}
// UpdateSubscriptionUsage adds cost to the daily, weekly and monthly usage
// counters of the cached subscription by running updateSubUsageScript. The
// script also refreshes the expiry to a jittered TTL in whole seconds.
//
// If the subscription is not cached, the script is a no-op and nil is
// returned. Redis errors are logged as a warning and returned.
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
	keys := []string{billingSubKey(userID, groupID)}
	ttlSeconds := int(jitteredTTL().Seconds())
	err := updateSubUsageScript.Run(ctx, c.rdb, keys, cost, ttlSeconds).Err()
	if err == nil || errors.Is(err, redis.Nil) {
		return nil
	}
	log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
	return err
}
// InvalidateSubscriptionCache deletes the cached subscription for
// (userID, groupID). Deleting a key that does not exist is not an error.
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
	return c.rdb.Del(ctx, billingSubKey(userID, groupID)).Err()
}
// GetAPIKeyRateLimit loads the cached rate-limit state for API key keyID.
//
// A missing entry is reported as redis.Nil (HGETALL yields an empty map for
// an absent key). Fields absent from the hash keep their zero value. A field
// that is present but malformed makes the entry invalid, and an error is
// returned. Reading such a value as 0 would under-report usage and let the key
// exceed its limits.
func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) {
	result, err := c.rdb.HGetAll(ctx, billingRateLimitKey(keyID)).Result()
	if err != nil {
		return nil, err
	}
	if len(result) == 0 {
		return nil, redis.Nil
	}

	data := &service.APIKeyRateLimitCacheData{}

	usageFields := []struct {
		name string
		dst  *float64
	}{
		{rateLimitFieldUsage5h, &data.Usage5h},
		{rateLimitFieldUsage1d, &data.Usage1d},
		{rateLimitFieldUsage7d, &data.Usage7d},
	}
	for _, f := range usageFields {
		raw, ok := result[f.name]
		if !ok {
			continue
		}
		v, err := strconv.ParseFloat(raw, 64)
		if err != nil {
			return nil, fmt.Errorf("invalid rate limit cache for api key %d: field %s: %w", keyID, f.name, err)
		}
		*f.dst = v
	}

	windowFields := []struct {
		name string
		dst  *int64
	}{
		{rateLimitFieldWindow5h, &data.Window5h},
		{rateLimitFieldWindow1d, &data.Window1d},
		{rateLimitFieldWindow7d, &data.Window7d},
	}
	for _, f := range windowFields {
		raw, ok := result[f.name]
		if !ok {
			continue
		}
		v, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("invalid rate limit cache for api key %d: field %s: %w", keyID, f.name, err)
		}
		*f.dst = v
	}

	return data, nil
}
// SetAPIKeyRateLimit writes all usage counters and window markers of data
// into the rate-limit hash for keyID and sets its expiry to rateLimitCacheTTL.
// Both commands are sent in a single pipeline. A nil data is ignored.
func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error {
	if data == nil {
		return nil
	}
	key := billingRateLimitKey(keyID)
	_, err := c.rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.HSet(ctx, key, map[string]any{
			rateLimitFieldUsage5h:  data.Usage5h,
			rateLimitFieldUsage1d:  data.Usage1d,
			rateLimitFieldUsage7d:  data.Usage7d,
			rateLimitFieldWindow5h: data.Window5h,
			rateLimitFieldWindow1d: data.Window1d,
			rateLimitFieldWindow7d: data.Window7d,
		})
		pipe.Expire(ctx, key, rateLimitCacheTTL)
		return nil
	})
	return err
}
// UpdateAPIKeyRateLimitUsage adds cost to the 5h/1d/7d usage counters of the
// cached rate-limit entry by running updateRateLimitUsageScript. The script
// also refreshes the expiry to rateLimitCacheTTL.
//
// If the entry is not cached, the script is a no-op and nil is returned.
// Redis errors are logged as a warning and returned.
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
	keys := []string{billingRateLimitKey(keyID)}
	ttlSeconds := int(rateLimitCacheTTL.Seconds())
	err := updateRateLimitUsageScript.Run(ctx, c.rdb, keys, cost, ttlSeconds).Err()
	if err == nil || errors.Is(err, redis.Nil) {
		return nil
	}
	log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
	return err
}
// InvalidateAPIKeyRateLimit deletes the cached rate-limit entry for keyID.
// Deleting a key that does not exist is not an error.
func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
	return c.rdb.Del(ctx, billingRateLimitKey(keyID)).Err()
}