refactor(backend): 添加 service 缓存端口

This commit is contained in:
Forest
2025-12-19 23:39:28 +08:00
parent ef81aeb463
commit 7bbf621490
27 changed files with 906 additions and 468 deletions

View File

@@ -41,7 +41,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
settingRepository := repository.NewSettingRepository(db)
settingService := service.NewSettingService(settingRepository, configConfig)
client := infrastructure.ProvideRedis(configConfig)
emailService := service.NewEmailService(settingRepository, client)
emailCache := repository.NewEmailCache(client)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileService := service.NewTurnstileService(settingService)
emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
@@ -51,15 +52,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig)
apiKeyCache := repository.NewApiKeyCache(client)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository)
billingCache := repository.NewBillingCache(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService)
redeemCache := repository.NewRedeemCache(client)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
accountRepository := repository.NewAccountRepository(db)
@@ -81,14 +85,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client)
pricingService, err := service.ProvidePricingService(configConfig)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(client)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
concurrencyService := service.NewConcurrencyService(client)
identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)

View File

@@ -6,6 +6,7 @@ import (
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/sysutil"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -19,8 +20,9 @@ type SystemHandler struct {
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
updateCache := repository.NewUpdateCache(rdb)
return &SystemHandler{
updateSvc: service.NewUpdateService(rdb, version, buildType),
updateSvc: service.NewUpdateService(updateCache, version, buildType),
}
}

View File

@@ -0,0 +1,51 @@
package repository
import (
"context"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
)
type apiKeyCache struct {
rdb *redis.Client
}
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
return &apiKeyCache{rdb: rdb}
}
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return c.rdb.Incr(ctx, apiKey).Err()
}
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}

View File

@@ -0,0 +1,174 @@
package repository
import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
)
var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
type billingCache struct {
rdb *redis.Client
}
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
return &billingCache{rdb: rdb}
}
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return 0, err
}
return strconv.ParseFloat(val, 64)
}
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
return c.parseSubscriptionCache(result)
}
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
result := &ports.SubscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
if data == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]interface{}{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
}
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,132 @@
package repository
import (
"context"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
accountConcurrencyKeyPrefix = "concurrency:account:"
userConcurrencyKeyPrefix = "concurrency:user:"
waitQueueKeyPrefix = "concurrency:wait:"
concurrencyTTL = 5 * time.Minute
)
var (
acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
`)
releaseScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
return 1
`)
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
type concurrencyCache struct {
rdb *redis.Client
}
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
}
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
return c.rdb.Get(ctx, key).Int()
}
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -0,0 +1,48 @@
package repository
import (
"context"
"encoding/json"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
type emailCache struct {
rdb *redis.Client
}
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
return &emailCache{rdb: rdb}
}
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
key := verifyCodeKeyPrefix + email
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data ports.VerificationCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKeyPrefix + email
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
key := verifyCodeKeyPrefix + email
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,35 @@
package repository
import (
"context"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const stickySessionPrefix = "sticky_session:"
type gatewayCache struct {
rdb *redis.Client
}
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
return &gatewayCache{rdb: rdb}
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
key := stickySessionPrefix + sessionHash
return c.rdb.Get(ctx, key).Int64()
}
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
return c.rdb.Set(ctx, key, accountID, ttl).Err()
}
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
return c.rdb.Expire(ctx, key, ttl).Err()
}

View File

@@ -0,0 +1,47 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
fingerprintKeyPrefix = "fingerprint:"
fingerprintTTL = 24 * time.Hour
)
type identityCache struct {
rdb *redis.Client
}
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
return &identityCache{rdb: rdb}
}
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var fp ports.Fingerprint
if err := json.Unmarshal([]byte(val), &fp); err != nil {
return nil, err
}
return &fp, nil
}
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := json.Marshal(fp)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
}

View File

@@ -0,0 +1,49 @@
package repository
import (
"context"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemRateLimitDuration = 24 * time.Hour
)
type redeemCache struct {
rdb *redis.Client
}
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
return &redeemCache{rdb: rdb}
}
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
key := redeemLockKeyPrefix + code
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
key := redeemLockKeyPrefix + code
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,28 @@
package repository
import (
"context"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const updateCacheKey = "update:latest"
type updateCache struct {
rdb *redis.Client
}
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
return &updateCache{rdb: rdb}
}
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
return c.rdb.Get(ctx, updateCacheKey).Result()
}
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
}

View File

@@ -19,6 +19,16 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository,
wire.Struct(new(Repositories), "*"),
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewApiKeyCache,
NewConcurrencyCache,
NewEmailCache,
NewIdentityCache,
NewRedeemCache,
NewUpdateCache,
// Bind concrete repositories to service port interfaces
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),

View File

@@ -27,9 +27,7 @@ var (
)
const (
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
apiKeyMaxErrorsPerHour = 20
apiKeyRateLimitDuration = time.Hour
apiKeyMaxErrorsPerHour = 20
)
// CreateApiKeyRequest 创建API Key请求
@@ -52,7 +50,7 @@ type ApiKeyService struct {
userRepo ports.UserRepository
groupRepo ports.GroupRepository
userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client
cache ports.ApiKeyCache
cfg *config.Config
}
@@ -62,7 +60,7 @@ func NewApiKeyService(
userRepo ports.UserRepository,
groupRepo ports.GroupRepository,
userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client,
cache ports.ApiKeyCache,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
@@ -70,7 +68,7 @@ func NewApiKeyService(
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
rdb: rdb,
cache: cache,
cfg: cfg,
}
}
@@ -113,13 +111,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
count, err := s.rdb.Get(ctx, key).Int()
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作
return nil
@@ -134,16 +130,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, _ = pipe.Exec(ctx)
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
}
// canUserBindGroup 检查用户是否可以绑定指定分组
@@ -273,7 +264,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
}
// 缓存到Redis可选TTL设置为5分钟
if s.rdb != nil {
if s.cache != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
}
@@ -326,9 +317,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
if req.Status != nil {
apiKey.Status = *req.Status
// 如果状态改变清除Redis缓存
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
_ = s.rdb.Del(ctx, cacheKey)
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
}
}
@@ -355,9 +345,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// 清除Redis缓存
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
_ = s.rdb.Del(ctx, cacheKey)
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
}
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
@@ -400,13 +389,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.rdb != nil {
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
return fmt.Errorf("increment usage: %w", err)
}
// 设置24小时过期
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
}
return nil
}

View File

@@ -5,30 +5,10 @@ import (
"errors"
"fmt"
"log"
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
// 缓存Key前缀和TTL
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
// 订阅缓存Hash字段
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
)
// 错误定义
@@ -38,35 +18,6 @@ var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
)
// 预编译的Lua脚本
var (
// deductBalanceScript: 扣减余额缓存key不存在则忽略
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// updateSubUsageScript: 更新订阅用量缓存key不存在则忽略
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
type subscriptionCacheData struct {
Status string
@@ -80,15 +31,15 @@ type subscriptionCacheData struct {
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
rdb *redis.Client
cache ports.BillingCache
userRepo ports.UserRepository
subRepo ports.UserSubscriptionRepository
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
return &BillingCacheService{
rdb: rdb,
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
}
@@ -100,24 +51,19 @@ func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, su
// GetUserBalance 获取用户余额(优先从缓存读取)
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
if s.rdb == nil {
if s.cache == nil {
// Redis不可用直接查询数据库
return s.getUserBalanceFromDB(ctx, userID)
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 尝试从缓存读取
val, err := s.rdb.Get(ctx, key).Result()
balance, err := s.cache.GetUserBalance(ctx, userID)
if err == nil {
balance, parseErr := strconv.ParseFloat(val, 64)
if parseErr == nil {
return balance, nil
}
return balance, nil
}
// 缓存未命中或解析错误,从数据库读取
balance, err := s.getUserBalanceFromDB(ctx, userID)
// 缓存未命中,从数据库读取
balance, err = s.getUserBalanceFromDB(ctx, userID)
if err != nil {
return 0, err
}
@@ -143,39 +89,28 @@ func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID i
// setBalanceCache 设置余额缓存
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
}
}
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 使用预编译的Lua脚本原子性扣减如果key不存在则忽略
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
return s.cache.DeductUserBalance(ctx, userID, amount)
}
// InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
return err
}
@@ -188,19 +123,14 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
if s.rdb == nil {
if s.cache == nil {
return s.getSubscriptionFromDB(ctx, userID, groupID)
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 尝试从缓存读取
result, err := s.rdb.HGetAll(ctx, key).Result()
if err == nil && len(result) > 0 {
data, parseErr := s.parseSubscriptionCache(result)
if parseErr == nil {
return data, nil
}
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
if err == nil && cacheData != nil {
return s.convertFromPortsData(cacheData), nil
}
// 缓存未命中,从数据库读取
@@ -219,6 +149,28 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
return data, nil
}
func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData {
return &subscriptionCacheData{
Status: data.Status,
ExpiresAt: data.ExpiresAt,
DailyUsage: data.DailyUsage,
WeeklyUsage: data.WeeklyUsage,
MonthlyUsage: data.MonthlyUsage,
Version: data.Version,
}
}
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData {
return &ports.SubscriptionCacheData{
Status: data.Status,
ExpiresAt: data.ExpiresAt,
DailyUsage: data.DailyUsage,
WeeklyUsage: data.WeeklyUsage,
MonthlyUsage: data.MonthlyUsage,
Version: data.Version,
}
}
// getSubscriptionFromDB 从数据库获取订阅数据
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
@@ -236,90 +188,30 @@ func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID,
}, nil
}
// parseSubscriptionCache 解析订阅缓存数据
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
result := &subscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
// setSubscriptionCache 设置订阅缓存
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
if s.rdb == nil || data == nil {
if s.cache == nil || data == nil {
return
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]interface{}{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := s.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
if _, err := pipe.Exec(ctx); err != nil {
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
}
}
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 使用预编译的Lua脚本原子性增加用量如果key不存在则忽略
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
}
// InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
return err
}

View File

@@ -2,22 +2,13 @@ package service
import (
"context"
"fmt"
"log"
"time"
"github.com/redis/go-redis/v9"
"sub2api/internal/service/ports"
)
const (
// Redis key prefixes
accountConcurrencyKey = "concurrency:account:"
userConcurrencyKey = "concurrency:user:"
userWaitCountKey = "concurrency:wait:"
// TTL for concurrency keys (auto-release safety net)
concurrencyKeyTTL = 10 * time.Minute
// Wait polling interval
waitPollInterval = 100 * time.Millisecond
@@ -28,70 +19,14 @@ const (
defaultExtraWaitSlots = 20
)
// Pre-compiled Lua scripts for better performance
var (
// acquireScript: increment counter if below max, return 1 if successful
acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
`)
// releaseScript: decrement counter, but don't go below 0
releaseScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// incrementWaitScript: increment wait counter if below max, return 1 if successful
incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
return 1
`)
// decrementWaitScript: decrement wait counter, but don't go below 0
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
// ConcurrencyService manages concurrent request limiting for accounts and users
type ConcurrencyService struct {
rdb *redis.Client
cache ports.ConcurrencyCache
}
// NewConcurrencyService creates a new ConcurrencyService
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
return &ConcurrencyService{rdb: rdb}
func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
return &ConcurrencyService{cache: cache}
}
// AcquireResult represents the result of acquiring a concurrency slot
@@ -104,20 +39,6 @@ type AcquireResult struct {
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
// If the user is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// acquireSlot is the core implementation for acquiring a concurrency slot
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 {
return &AcquireResult{
@@ -126,8 +47,7 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
}, nil
}
// Try to acquire immediately
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, err
}
@@ -135,64 +55,56 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
if acquired {
return &AcquireResult{
Acquired: true,
ReleaseFunc: s.makeReleaseFunc(key),
ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
}
},
}, nil
}
// Not acquired, return with Acquired=false
// The caller (gateway handler) will handle waiting with ping support
return &AcquireResult{
Acquired: false,
ReleaseFunc: nil,
}, nil
}
// tryAcquire attempts to increment the counter if below max
// Uses pre-compiled Lua script for atomicity and performance
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
// If the user is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {}, // no-op
}, nil
}
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return false, fmt.Errorf("acquire slot failed: %w", err)
return nil, err
}
return result == 1, nil
}
// makeReleaseFunc creates a function to release a concurrency slot
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
return func() {
// Use background context to ensure release even if original context is cancelled
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
// Log error but don't panic - TTL will eventually clean up
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
}
if acquired {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
}
},
}, nil
}
}
// GetCurrentCount returns the current concurrency count for debugging/monitoring
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
val, err := s.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, err
}
return val, nil
}
// GetAccountCurrentCount returns current concurrency count for an account
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.GetCurrentCount(ctx, key)
}
// GetUserCurrentCount returns current concurrency count for a user
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.GetCurrentCount(ctx, key)
return &AcquireResult{
Acquired: false,
ReleaseFunc: nil,
}, nil
}
// ============================================
@@ -203,44 +115,36 @@ func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int
// Returns true if successful, false if the wait queue is full.
// maxWait should be user.Concurrency + defaultExtraWaitSlots
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
if s.rdb == nil {
if s.cache == nil {
// Redis not available, allow request
return true, nil
}
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
if err != nil {
// On error, allow the request to proceed (fail open)
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
return true, nil
}
return result == 1, nil
return result, nil
}
// DecrementWaitCount decrements the wait queue counter for a user.
// Should be called when a request completes or exits the wait queue.
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
// Use background context to ensure decrement even if original context is cancelled
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
}
}
// GetUserWaitCount returns current wait queue count for a user
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
return s.GetCurrentCount(ctx, key)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int {

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"math/big"
@@ -13,8 +12,6 @@ import (
"sub2api/internal/model"
"sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
)
var (
@@ -25,19 +22,11 @@ var (
)
const (
verifyCodeKeyPrefix = "email_verify:"
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5
)
// verifyCodeData Redis 中存储的验证码数据
type verifyCodeData struct {
Code string `json:"code"`
Attempts int `json:"attempts"`
CreatedAt time.Time `json:"created_at"`
}
// SmtpConfig SMTP配置
type SmtpConfig struct {
Host string
@@ -52,14 +41,14 @@ type SmtpConfig struct {
// EmailService 邮件服务
type EmailService struct {
settingRepo ports.SettingRepository
rdb *redis.Client
cache ports.EmailCache
}
// NewEmailService 创建邮件服务实例
func NewEmailService(settingRepo ports.SettingRepository, rdb *redis.Client) *EmailService {
func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService {
return &EmailService{
settingRepo: settingRepo,
rdb: rdb,
cache: cache,
}
}
@@ -201,10 +190,8 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
// SendVerifyCode 发送验证码邮件
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
key := verifyCodeKeyPrefix + email
// 检查是否在冷却期内
existing, err := s.getVerifyCodeData(ctx, key)
existing, err := s.cache.GetVerificationCode(ctx, email)
if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent
@@ -218,12 +205,12 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
}
// 保存验证码到 Redis
data := &verifyCodeData{
data := &ports.VerificationCodeData{
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
}
if err := s.setVerifyCodeData(ctx, key, data); err != nil {
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
}
@@ -241,9 +228,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
// VerifyCode 验证验证码
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
key := verifyCodeKeyPrefix + email
data, err := s.getVerifyCodeData(ctx, key)
data, err := s.cache.GetVerificationCode(ctx, email)
if err != nil || data == nil {
return ErrInvalidVerifyCode
}
@@ -256,7 +241,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配
if data.Code != code {
data.Attempts++
_ = s.setVerifyCodeData(ctx, key, data)
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
@@ -264,32 +249,10 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
}
// 验证成功,删除验证码
s.rdb.Del(ctx, key)
_ = s.cache.DeleteVerificationCode(ctx, email)
return nil
}
// getVerifyCodeData 从 Redis 获取验证码数据
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
val, err := s.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data verifyCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
// setVerifyCodeData 保存验证码数据到 Redis
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
val, err := json.Marshal(data)
if err != nil {
return err
}
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
}
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
return fmt.Sprintf(`

View File

@@ -24,13 +24,11 @@ import (
"sub2api/internal/service/ports"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionPrefix = "sticky_session:"
stickySessionTTL = time.Hour // 粘性会话TTL
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
)
@@ -82,7 +80,7 @@ type GatewayService struct {
usageLogRepo ports.UsageLogRepository
userRepo ports.UserRepository
userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client
cache ports.GatewayCache
cfg *config.Config
oauthService *OAuthService
billingService *BillingService
@@ -98,7 +96,7 @@ func NewGatewayService(
usageLogRepo ports.UsageLogRepository,
userRepo ports.UserRepository,
userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client,
cache ports.GatewayCache,
cfg *config.Config,
oauthService *OAuthService,
billingService *BillingService,
@@ -124,7 +122,7 @@ func NewGatewayService(
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
rdb: rdb,
cache: cache,
cfg: cfg,
oauthService: oauthService,
billingService: billingService,
@@ -290,14 +288,14 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// 续期粘性会话
s.rdb.Expire(ctx, stickySessionPrefix+sessionHash, stickySessionTTL)
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
return account, nil
}
}
@@ -347,7 +345,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// 4. 建立粘性绑定
if sessionHash != "" {
s.rdb.Set(ctx, stickySessionPrefix+sessionHash, selected.ID, stickySessionTTL)
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
}
return selected, nil
@@ -526,7 +524,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// OAuth账号应用统一指纹
var fingerprint *Fingerprint
var fingerprint *ports.Fingerprint
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)

View File

@@ -11,15 +11,10 @@ import (
"net/http"
"regexp"
"strconv"
"sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
)
const (
// Redis key prefix
identityFingerprintKey = "identity:fingerprint:"
)
// 预编译正则表达式(避免每次调用重新编译)
var (
@@ -29,20 +24,8 @@ var (
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
// Fingerprint 存储的指纹数据结构
type Fingerprint struct {
ClientID string `json:"client_id"` // 64位hex客户端ID首次随机生成
UserAgent string `json:"user_agent"` // User-Agent
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
}
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
var defaultFingerprint = ports.Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.52.0",
@@ -54,39 +37,31 @@ var defaultFingerprint = Fingerprint{
// IdentityService 管理OAuth账号的请求身份指纹
type IdentityService struct {
rdb *redis.Client
cache ports.IdentityCache
}
// NewIdentityService 创建新的IdentityService
func NewIdentityService(rdb *redis.Client) *IdentityService {
return &IdentityService{rdb: rdb}
func NewIdentityService(cache ports.IdentityCache) *IdentityService {
return &IdentityService{cache: cache}
}
// GetOrCreateFingerprint 获取或创建账号的指纹
// 如果缓存存在检测user-agent版本新版本则更新
// 如果缓存不存在生成随机ClientID并从请求头创建指纹然后缓存
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
key := identityFingerprintKey + strconv.FormatInt(accountID, 10)
// 尝试从Redis获取缓存的指纹
data, err := s.rdb.Get(ctx, key).Bytes()
if err == nil && len(data) > 0 {
// 缓存存在,解析指纹
var cached Fingerprint
if err := json.Unmarshal(data, &cached); err == nil {
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
// 更新user-agent
cached.UserAgent = clientUA
// 保存更新后的指纹
if newData, err := json.Marshal(cached); err == nil {
s.rdb.Set(ctx, key, newData, 0) // 永不过期
}
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
}
return &cached, nil
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) {
// 尝试从缓存获取指纹
cached, err := s.cache.GetFingerprint(ctx, accountID)
if err == nil && cached != nil {
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
// 更新user-agent
cached.UserAgent = clientUA
// 保存更新后的指纹
_ = s.cache.SetFingerprint(ctx, accountID, cached)
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
}
return cached, nil
}
// 缓存不存在或解析失败,创建新指纹
@@ -95,11 +70,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 生成随机ClientID
fp.ClientID = generateClientID()
// 保存到Redis(永不过期)
if data, err := json.Marshal(fp); err == nil {
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil {
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
// 保存到缓存(永不过期)
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
@@ -107,8 +80,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
}
// createFingerprintFromHeaders 从请求头创建指纹
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
fp := &Fingerprint{}
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint {
fp := &ports.Fingerprint{}
// 获取User-Agent
if ua := headers.Get("User-Agent"); ua != "" {
@@ -137,7 +110,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
}
// ApplyFingerprint 将指纹应用到请求头覆盖原有的x-stainless-*头)
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) {
if fp == nil {
return
}

View File

@@ -0,0 +1,16 @@
package ports
import (
"context"
"time"
)
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}

View File

@@ -0,0 +1,31 @@
package ports
import (
"context"
"time"
)
// SubscriptionCacheData represents cached subscription data
type SubscriptionCacheData struct {
Status string
ExpiresAt time.Time
DailyUsage float64
WeeklyUsage float64
MonthlyUsage float64
Version int64
}
// BillingCache defines cache operations for billing service
type BillingCache interface {
// Balance operations
GetUserBalance(ctx context.Context, userID int64) (float64, error)
SetUserBalance(ctx context.Context, userID int64, balance float64) error
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
InvalidateUserBalance(ctx context.Context, userID int64) error
// Subscription operations
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
}

View File

@@ -0,0 +1,19 @@
package ports
import "context"
// ConcurrencyCache defines cache operations for concurrency service
type ConcurrencyCache interface {
// Slot management
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
}

View File

@@ -0,0 +1,20 @@
package ports
import (
"context"
"time"
)
// VerificationCodeData represents verification code data
type VerificationCodeData struct {
Code string
Attempts int
CreatedAt time.Time
}
// EmailCache defines cache operations for email service
type EmailCache interface {
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
}

View File

@@ -0,0 +1,13 @@
package ports
import (
"context"
"time"
)
// GatewayCache defines cache operations for gateway service
type GatewayCache interface {
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}

View File

@@ -0,0 +1,21 @@
package ports
import "context"
// Fingerprint represents account fingerprint data
type Fingerprint struct {
ClientID string
UserAgent string
StainlessLang string
StainlessPackageVersion string
StainlessOS string
StainlessArch string
StainlessRuntime string
StainlessRuntimeVersion string
}
// IdentityCache defines cache operations for identity service
type IdentityCache interface {
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
}

View File

@@ -0,0 +1,15 @@
package ports
import (
"context"
"time"
)
// RedeemCache defines cache operations for redeem service
type RedeemCache interface {
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementRedeemAttemptCount(ctx context.Context, userID int64) error
AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error)
ReleaseRedeemLock(ctx context.Context, code string) error
}

View File

@@ -0,0 +1,12 @@
package ports
import (
"context"
"time"
)
// UpdateCache defines cache operations for update service
type UpdateCache interface {
GetUpdateInfo(ctx context.Context) (string, error)
SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error
}

View File

@@ -26,11 +26,9 @@ var (
)
const (
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemMaxErrorsPerHour = 20
redeemRateLimitDuration = time.Hour
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
redeemMaxErrorsPerHour = 20
redeemRateLimitDuration = time.Hour
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
)
// GenerateCodesRequest 生成兑换码请求
@@ -53,7 +51,7 @@ type RedeemService struct {
redeemRepo ports.RedeemCodeRepository
userRepo ports.UserRepository
subscriptionService *SubscriptionService
rdb *redis.Client
cache ports.RedeemCache
billingCacheService *BillingCacheService
}
@@ -62,14 +60,14 @@ func NewRedeemService(
redeemRepo ports.RedeemCodeRepository,
userRepo ports.UserRepository,
subscriptionService *SubscriptionService,
rdb *redis.Client,
cache ports.RedeemCache,
billingCacheService *BillingCacheService,
) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
rdb: rdb,
cache: cache,
billingCacheService: billingCacheService,
}
}
@@ -140,13 +138,11 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
count, err := s.rdb.Get(ctx, key).Int()
count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作
return nil
@@ -161,27 +157,21 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64)
// incrementRedeemErrorCount 增加用户兑换错误计数
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, _ = pipe.Exec(ctx)
_ = s.cache.IncrementRedeemAttemptCount(ctx, userID)
}
// acquireRedeemLock 尝试获取兑换码的分布式锁
// 返回 true 表示获取成功false 表示锁已被占用
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
if s.rdb == nil {
if s.cache == nil {
return true // 无 Redis 时降级为不加锁
}
key := redeemLockKeyPrefix + code
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration)
if err != nil {
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
return true
@@ -191,12 +181,11 @@ func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool
// releaseRedeemLock 释放兑换码的分布式锁
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := redeemLockKeyPrefix + code
s.rdb.Del(ctx, key)
_ = s.cache.ReleaseRedeemLock(ctx, code)
}
// Redeem 使用兑换码

View File

@@ -18,7 +18,7 @@ import (
"strings"
"time"
"github.com/redis/go-redis/v9"
"sub2api/internal/service/ports"
)
const (
@@ -36,15 +36,15 @@ const (
// UpdateService handles software updates
type UpdateService struct {
rdb *redis.Client
cache ports.UpdateCache
currentVersion string
buildType string // "source" for manual builds, "release" for CI builds
}
// NewUpdateService creates a new UpdateService
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService {
func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService {
return &UpdateService{
rdb: rdb,
cache: cache,
currentVersion: version,
buildType: buildType,
}
@@ -533,7 +533,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
}
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
data, err := s.cache.GetUpdateInfo(ctx)
if err != nil {
return nil, err
}
@@ -573,7 +573,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
}
data, _ := json.Marshal(cacheData)
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second)
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
}
// compareVersions compares two semantic versions