From 7bbf621490f473c0364b349e9dae2a854d34a6df Mon Sep 17 00:00:00 2001 From: Forest Date: Fri, 19 Dec 2025 23:39:28 +0800 Subject: [PATCH] =?UTF-8?q?refactor(backend):=20=E6=B7=BB=E5=8A=A0=20servi?= =?UTF-8?q?ce=20=E7=BC=93=E5=AD=98=E7=AB=AF=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/wire_gen.go | 21 +- .../internal/handler/admin/system_handler.go | 4 +- backend/internal/repository/api_key_cache.go | 51 +++++ backend/internal/repository/billing_cache.go | 174 +++++++++++++++ .../internal/repository/concurrency_cache.go | 132 ++++++++++++ backend/internal/repository/email_cache.go | 48 +++++ backend/internal/repository/gateway_cache.go | 35 +++ backend/internal/repository/identity_cache.go | 47 ++++ backend/internal/repository/redeem_cache.go | 49 +++++ backend/internal/repository/update_cache.go | 28 +++ backend/internal/repository/wire.go | 10 + backend/internal/service/api_key_service.go | 43 ++-- .../internal/service/billing_cache_service.go | 200 ++++-------------- .../internal/service/concurrency_service.go | 188 ++++------------ backend/internal/service/email_service.go | 55 +---- backend/internal/service/gateway_service.go | 16 +- backend/internal/service/identity_service.go | 75 +++---- .../internal/service/ports/api_key_cache.go | 16 ++ .../internal/service/ports/billing_cache.go | 31 +++ .../service/ports/concurrency_cache.go | 19 ++ backend/internal/service/ports/email_cache.go | 20 ++ .../internal/service/ports/gateway_cache.go | 13 ++ .../internal/service/ports/identity_cache.go | 21 ++ .../internal/service/ports/redeem_cache.go | 15 ++ .../internal/service/ports/update_cache.go | 12 ++ backend/internal/service/redeem_service.go | 39 ++-- backend/internal/service/update_service.go | 12 +- 27 files changed, 906 insertions(+), 468 deletions(-) create mode 100644 backend/internal/repository/api_key_cache.go create mode 100644 backend/internal/repository/billing_cache.go create mode 100644 backend/internal/repository/concurrency_cache.go create mode 100644 backend/internal/repository/email_cache.go create mode 100644 backend/internal/repository/gateway_cache.go create mode 100644 backend/internal/repository/identity_cache.go create mode 100644 backend/internal/repository/redeem_cache.go create mode 100644 backend/internal/repository/update_cache.go create mode 100644 backend/internal/service/ports/api_key_cache.go create mode 100644 backend/internal/service/ports/billing_cache.go create mode 100644 backend/internal/service/ports/concurrency_cache.go create mode 100644 backend/internal/service/ports/email_cache.go create mode 100644 backend/internal/service/ports/gateway_cache.go create mode 100644 backend/internal/service/ports/identity_cache.go create mode 100644 backend/internal/service/ports/redeem_cache.go create mode 100644 backend/internal/service/ports/update_cache.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 94db1128..5f13e3cc 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -41,7 +41,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { settingRepository := repository.NewSettingRepository(db) settingService := service.NewSettingService(settingRepository, configConfig) client := infrastructure.ProvideRedis(configConfig) - emailService := service.NewEmailService(settingRepository, client) + emailCache := repository.NewEmailCache(client) + emailService := service.NewEmailService(settingRepository, emailCache) turnstileService := service.NewTurnstileService(settingService) emailQueueService := service.ProvideEmailQueueService(emailService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) @@ -51,15 +52,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyRepository := repository.NewApiKeyRepository(db) groupRepository := repository.NewGroupRepository(db) userSubscriptionRepository := repository.NewUserSubscriptionRepository(db) - apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig) + apiKeyCache := repository.NewApiKeyCache(client) + apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(db) usageService := service.NewUsageService(usageLogRepository, userRepository) usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(db) - billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository) + billingCache := repository.NewBillingCache(client) + billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService) + redeemCache := repository.NewRedeemCache(client) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) accountRepository := repository.NewAccountRepository(db) @@ -81,14 +85,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) + gatewayCache := repository.NewGatewayCache(client) pricingService, err := service.ProvidePricingService(configConfig) if err != nil { return nil, err } billingService := service.NewBillingService(configConfig, pricingService) - identityService := service.NewIdentityService(client) - gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService) - concurrencyService := service.NewConcurrencyService(client) + identityCache := repository.NewIdentityCache(client) + identityService := service.NewIdentityService(identityCache) + gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService) + concurrencyCache := repository.NewConcurrencyCache(client) + concurrencyService := service.NewConcurrencyService(concurrencyCache) gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler) diff --git a/backend/internal/handler/admin/system_handler.go b/backend/internal/handler/admin/system_handler.go index cd145e52..4e5f40ba 100644 --- a/backend/internal/handler/admin/system_handler.go +++ b/backend/internal/handler/admin/system_handler.go @@ -6,6 +6,7 @@ import ( "sub2api/internal/pkg/response" "sub2api/internal/pkg/sysutil" + "sub2api/internal/repository" "sub2api/internal/service" "github.com/gin-gonic/gin" @@ -19,8 +20,9 @@ type SystemHandler struct { // NewSystemHandler creates a new SystemHandler func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler { + updateCache := repository.NewUpdateCache(rdb) return &SystemHandler{ - updateSvc: service.NewUpdateService(rdb, version, buildType), + updateSvc: service.NewUpdateService(updateCache, version, buildType), } } diff --git a/backend/internal/repository/api_key_cache.go b/backend/internal/repository/api_key_cache.go new file mode 100644 index 00000000..479f7d55 --- /dev/null +++ b/backend/internal/repository/api_key_cache.go @@ -0,0 +1,51 @@ +package repository + +import ( + "context" + "fmt" + "time" + + "sub2api/internal/service/ports" + + "github.com/redis/go-redis/v9" +) + +const ( + apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" + apiKeyRateLimitDuration = 24 * time.Hour +) + +type apiKeyCache struct { + rdb *redis.Client +} + +func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache { + return &apiKeyCache{rdb: rdb} +} + +func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { + key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) + return c.rdb.Get(ctx, key).Int() +} + +func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { + key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) + pipe := c.rdb.Pipeline() + pipe.Incr(ctx, key) + pipe.Expire(ctx, key, apiKeyRateLimitDuration) + _, err := pipe.Exec(ctx) + return err +} + +func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { + key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) + return c.rdb.Del(ctx, key).Err() +} + +func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error { + return c.rdb.Incr(ctx, apiKey).Err() +} + +func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { + return c.rdb.Expire(ctx, apiKey, ttl).Err() +} diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go new file mode 100644 index 00000000..99ec2b76 --- /dev/null +++ b/backend/internal/repository/billing_cache.go @@ -0,0 +1,174 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "log" + "strconv" + "time" + + "sub2api/internal/service/ports" + + "github.com/redis/go-redis/v9" +) + +const ( + billingBalanceKeyPrefix = "billing:balance:" + billingSubKeyPrefix = "billing:sub:" + billingCacheTTL = 5 * time.Minute +) + +const ( + subFieldStatus = "status" + subFieldExpiresAt = "expires_at" + subFieldDailyUsage = "daily_usage" + subFieldWeeklyUsage = "weekly_usage" + subFieldMonthlyUsage = "monthly_usage" + subFieldVersion = "version" +) + +var ( + deductBalanceScript = redis.NewScript(` + local current = redis.call('GET', KEYS[1]) + if current == false then + return 0 + end + local newVal = tonumber(current) - tonumber(ARGV[1]) + redis.call('SET', KEYS[1], newVal) + redis.call('EXPIRE', KEYS[1], ARGV[2]) + return 1 + `) + + updateSubUsageScript = redis.NewScript(` + local exists = redis.call('EXISTS', KEYS[1]) + if exists == 0 then + return 0 + end + local cost = tonumber(ARGV[1]) + redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost) + redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost) + redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost) + redis.call('EXPIRE', KEYS[1], ARGV[2]) + return 1 + `) +) + +type billingCache struct { + rdb *redis.Client +} + +func NewBillingCache(rdb *redis.Client) ports.BillingCache { + return &billingCache{rdb: rdb} +} + +func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) { + key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + val, err := c.rdb.Get(ctx, key).Result() + if err != nil { + return 0, err + } + return strconv.ParseFloat(val, 64) +} + +func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { + key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err() +} + +func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() + if err != nil && !errors.Is(err, redis.Nil) { + log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) + } + return nil +} + +func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error { + key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + return c.rdb.Del(ctx, key).Err() +} + +func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) { + key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + result, err := c.rdb.HGetAll(ctx, key).Result() + if err != nil { + return nil, err + } + if len(result) == 0 { + return nil, redis.Nil + } + return c.parseSubscriptionCache(result) +} + +func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) { + result := &ports.SubscriptionCacheData{} + + result.Status = data[subFieldStatus] + if result.Status == "" { + return nil, errors.New("invalid cache: missing status") + } + + if expiresStr, ok := data[subFieldExpiresAt]; ok { + expiresAt, err := strconv.ParseInt(expiresStr, 10, 64) + if err == nil { + result.ExpiresAt = time.Unix(expiresAt, 0) + } + } + + if dailyStr, ok := data[subFieldDailyUsage]; ok { + result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64) + } + + if weeklyStr, ok := data[subFieldWeeklyUsage]; ok { + result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64) + } + + if monthlyStr, ok := data[subFieldMonthlyUsage]; ok { + result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64) + } + + if versionStr, ok := data[subFieldVersion]; ok { + result.Version, _ = strconv.ParseInt(versionStr, 10, 64) + } + + return result, nil +} + +func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error { + if data == nil { + return nil + } + + key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + + fields := map[string]interface{}{ + subFieldStatus: data.Status, + subFieldExpiresAt: data.ExpiresAt.Unix(), + subFieldDailyUsage: data.DailyUsage, + subFieldWeeklyUsage: data.WeeklyUsage, + subFieldMonthlyUsage: data.MonthlyUsage, + subFieldVersion: data.Version, + } + + pipe := c.rdb.Pipeline() + pipe.HSet(ctx, key, fields) + pipe.Expire(ctx, key, billingCacheTTL) + _, err := pipe.Exec(ctx) + return err +} + +func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { + key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result() + if err != nil && !errors.Is(err, redis.Nil) { + log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) + } + return nil +} + +func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { + key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go new file mode 100644 index 00000000..b5f500a3 --- /dev/null +++ b/backend/internal/repository/concurrency_cache.go @@ -0,0 +1,132 @@ +package repository + +import ( + "context" + "fmt" + "time" + + "sub2api/internal/service/ports" + + "github.com/redis/go-redis/v9" +) + +const ( + accountConcurrencyKeyPrefix = "concurrency:account:" + userConcurrencyKeyPrefix = "concurrency:user:" + waitQueueKeyPrefix = "concurrency:wait:" + concurrencyTTL = 5 * time.Minute +) + +var ( + acquireScript = redis.NewScript(` + local current = redis.call('GET', KEYS[1]) + if current == false then + current = 0 + else + current = tonumber(current) + end + if current < tonumber(ARGV[1]) then + redis.call('INCR', KEYS[1]) + redis.call('EXPIRE', KEYS[1], ARGV[2]) + return 1 + end + return 0 + `) + + releaseScript = redis.NewScript(` + local current = redis.call('GET', KEYS[1]) + if current ~= false and tonumber(current) > 0 then + redis.call('DECR', KEYS[1]) + end + return 1 + `) + + incrementWaitScript = redis.NewScript(` + local waitKey = KEYS[1] + local maxWait = tonumber(ARGV[1]) + local ttl = tonumber(ARGV[2]) + local current = redis.call('GET', waitKey) + if current == false then + current = 0 + else + current = tonumber(current) + end + if current >= maxWait then + return 0 + end + redis.call('INCR', waitKey) + redis.call('EXPIRE', waitKey, ttl) + return 1 + `) + + decrementWaitScript = redis.NewScript(` + local current = redis.call('GET', KEYS[1]) + if current ~= false and tonumber(current) > 0 then + redis.call('DECR', KEYS[1]) + end + return 1 + `) +) + +type concurrencyCache struct { + rdb *redis.Client +} + +func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache { + return &concurrencyCache{rdb: rdb} +} + +func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) { + key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) + result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int() + if err != nil { + return false, err + } + return result == 1, nil +} + +func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error { + key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) + _, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result() + return err +} + +func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) + return c.rdb.Get(ctx, key).Int() +} + +func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) { + key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) + result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int() + if err != nil { + return false, err + } + return result == 1, nil +} + +func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error { + key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) + _, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result() + return err +} + +func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) + return c.rdb.Get(ctx, key).Int() +} + +func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int() + if err != nil { + return false, err + } + return result == 1, nil +} + +func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { + key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() + return err +} diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go new file mode 100644 index 00000000..2bd1f4c1 --- /dev/null +++ b/backend/internal/repository/email_cache.go @@ -0,0 +1,48 @@ +package repository + +import ( + "context" + "encoding/json" + "time" + + "sub2api/internal/service/ports" + + "github.com/redis/go-redis/v9" +) + +const verifyCodeKeyPrefix = "verify_code:" + +type emailCache struct { + rdb *redis.Client +} + +func NewEmailCache(rdb *redis.Client) ports.EmailCache { + return &emailCache{rdb: rdb} +} + +func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) { + key := verifyCodeKeyPrefix + email + val, err := c.rdb.Get(ctx, key).Result() + if err != nil { + return nil, err + } + var data ports.VerificationCodeData + if err := json.Unmarshal([]byte(val), &data); err != nil { + return nil, err + } + return &data, nil +} + +func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error { + key := verifyCodeKeyPrefix + email + val, err := json.Marshal(data) + if err != nil { + return err + } + return c.rdb.Set(ctx, key, val, ttl).Err() +} + +func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error { + key := verifyCodeKeyPrefix + email + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go new file mode 100644 index 00000000..7c87cb8c --- /dev/null +++ b/backend/internal/repository/gateway_cache.go @@ -0,0 +1,35 @@ +package repository + +import ( + "context" + "time" + + "sub2api/internal/service/ports" + + "github.com/redis/go-redis/v9" +) + +const stickySessionPrefix = "sticky_session:" + +type gatewayCache struct { + rdb *redis.Client +} + +func NewGatewayCache(rdb *redis.Client) ports.GatewayCache { + return &gatewayCache{rdb: rdb} +} + +func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { + key := stickySessionPrefix + sessionHash + return c.rdb.Get(ctx, key).Int64() +} + +func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { + key := stickySessionPrefix + sessionHash + return c.rdb.Set(ctx, key, accountID, ttl).Err() +} + +func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { + key := stickySessionPrefix + sessionHash + return c.rdb.Expire(ctx, key, ttl).Err() +} diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go new file mode 100644 index 00000000..0be0def3 --- /dev/null +++ b/backend/internal/repository/identity_cache.go @@ -0,0 +1,47 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "sub2api/internal/service/ports" + + "github.com/redis/go-redis/v9" +) + +const ( + fingerprintKeyPrefix = "fingerprint:" + fingerprintTTL = 24 * time.Hour +) + +type identityCache struct { + rdb *redis.Client +} + +func NewIdentityCache(rdb *redis.Client) ports.IdentityCache { + return &identityCache{rdb: rdb} +} + +func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) { + key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) + val, err := c.rdb.Get(ctx, key).Result() + if err != nil { + return nil, err + } + var fp ports.Fingerprint + if err := json.Unmarshal([]byte(val), &fp); err != nil { + return nil, err + } + return &fp, nil +} + +func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error { + key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) + val, err := json.Marshal(fp) + if err != nil { + return err + } + return c.rdb.Set(ctx, key, val, fingerprintTTL).Err() +} diff --git a/backend/internal/repository/redeem_cache.go b/backend/internal/repository/redeem_cache.go new file mode 100644 index 00000000..3e7d4178 --- /dev/null +++ b/backend/internal/repository/redeem_cache.go @@ -0,0 +1,49 @@ +package repository + +import ( + "context" + "fmt" + "time" + + "sub2api/internal/service/ports" + + "github.com/redis/go-redis/v9" +) + +const ( + redeemRateLimitKeyPrefix = "redeem:ratelimit:" + redeemLockKeyPrefix = "redeem:lock:" + redeemRateLimitDuration = 24 * time.Hour +) + +type redeemCache struct { + rdb *redis.Client +} + +func NewRedeemCache(rdb *redis.Client) ports.RedeemCache { + return &redeemCache{rdb: rdb} +} + +func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) { + key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) + return c.rdb.Get(ctx, key).Int() +} + +func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error { + key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) + pipe := c.rdb.Pipeline() + pipe.Incr(ctx, key) + pipe.Expire(ctx, key, redeemRateLimitDuration) + _, err := pipe.Exec(ctx) + return err +} + +func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) { + key := redeemLockKeyPrefix + code + return c.rdb.SetNX(ctx, key, 1, ttl).Result() +} + +func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error { + key := redeemLockKeyPrefix + code + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/update_cache.go b/backend/internal/repository/update_cache.go new file mode 100644 index 00000000..39d9c68b --- /dev/null +++ b/backend/internal/repository/update_cache.go @@ -0,0 +1,28 @@ +package repository + +import ( + "context" + "time" + + "sub2api/internal/service/ports" + + "github.com/redis/go-redis/v9" +) + +const updateCacheKey = "update:latest" + +type updateCache struct { + rdb *redis.Client +} + +func NewUpdateCache(rdb *redis.Client) ports.UpdateCache { + return &updateCache{rdb: rdb} +} + +func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) { + return c.rdb.Get(ctx, updateCacheKey).Result() +} + +func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error { + return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err() +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index a5909b9b..109afc35 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -19,6 +19,16 @@ var ProviderSet = wire.NewSet( NewUserSubscriptionRepository, wire.Struct(new(Repositories), "*"), + // Cache implementations + NewGatewayCache, + NewBillingCache, + NewApiKeyCache, + NewConcurrencyCache, + NewEmailCache, + NewIdentityCache, + NewRedeemCache, + NewUpdateCache, + // Bind concrete repositories to service port interfaces wire.Bind(new(ports.UserRepository), new(*UserRepository)), wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)), diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 1e98888d..adbd0f46 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -27,9 +27,7 @@ var ( ) const ( - apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:" - apiKeyMaxErrorsPerHour = 20 - apiKeyRateLimitDuration = time.Hour + apiKeyMaxErrorsPerHour = 20 ) // CreateApiKeyRequest 创建API Key请求 @@ -52,7 +50,7 @@ type ApiKeyService struct { userRepo ports.UserRepository groupRepo ports.GroupRepository userSubRepo ports.UserSubscriptionRepository - rdb *redis.Client + cache ports.ApiKeyCache cfg *config.Config } @@ -62,7 +60,7 @@ func NewApiKeyService( userRepo ports.UserRepository, groupRepo ports.GroupRepository, userSubRepo ports.UserSubscriptionRepository, - rdb *redis.Client, + cache ports.ApiKeyCache, cfg *config.Config, ) *ApiKeyService { return &ApiKeyService{ @@ -70,7 +68,7 @@ func NewApiKeyService( userRepo: userRepo, groupRepo: groupRepo, userSubRepo: userSubRepo, - rdb: rdb, + cache: cache, cfg: cfg, } } @@ -113,13 +111,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error { // checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限 func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error { - if s.rdb == nil { + if s.cache == nil { return nil } - key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) - - count, err := s.rdb.Get(ctx, key).Int() + count, err := s.cache.GetCreateAttemptCount(ctx, userID) if err != nil && !errors.Is(err, redis.Nil) { // Redis 出错时不阻止用户操作 return nil @@ -134,16 +130,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) // incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数 func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) { - if s.rdb == nil { + if s.cache == nil { return } - key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) - - pipe := s.rdb.Pipeline() - pipe.Incr(ctx, key) - pipe.Expire(ctx, key, apiKeyRateLimitDuration) - _, _ = pipe.Exec(ctx) + _ = s.cache.IncrementCreateAttemptCount(ctx, userID) } // canUserBindGroup 检查用户是否可以绑定指定分组 @@ -273,7 +264,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey } // 缓存到Redis(可选,TTL设置为5分钟) - if s.rdb != nil { + if s.cache != nil { // 这里可以序列化并缓存API Key _ = cacheKey // 使用变量避免未使用错误 } @@ -326,9 +317,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req if req.Status != nil { apiKey.Status = *req.Status // 如果状态改变,清除Redis缓存 - if s.rdb != nil { - cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key) - _ = s.rdb.Del(ctx, cacheKey) + if s.cache != nil { + _ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID) } } @@ -355,9 +345,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro } // 清除Redis缓存 - if s.rdb != nil { - cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key) - _ = s.rdb.Del(ctx, cacheKey) + if s.cache != nil { + _ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID) } if err := s.apiKeyRepo.Delete(ctx, id); err != nil { @@ -400,13 +389,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api // IncrementUsage 增加API Key使用次数(可选:用于统计) func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { // 使用Redis计数器 - if s.rdb != nil { + if s.cache != nil { cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02")) - if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil { + if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil { return fmt.Errorf("increment usage: %w", err) } // 设置24小时过期 - _ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour) + _ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour) } return nil } diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 39384da6..3b16ee12 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -5,30 +5,10 @@ import ( "errors" "fmt" "log" - "strconv" "time" "sub2api/internal/model" "sub2api/internal/service/ports" - - "github.com/redis/go-redis/v9" -) - -// 缓存Key前缀和TTL -const ( - billingBalanceKeyPrefix = "billing:balance:" - billingSubKeyPrefix = "billing:sub:" - billingCacheTTL = 5 * time.Minute -) - -// 订阅缓存Hash字段 -const ( - subFieldStatus = "status" - subFieldExpiresAt = "expires_at" - subFieldDailyUsage = "daily_usage" - subFieldWeeklyUsage = "weekly_usage" - subFieldMonthlyUsage = "monthly_usage" - subFieldVersion = "version" ) // 错误定义 @@ -38,35 +18,6 @@ var ( ErrSubscriptionInvalid = errors.New("subscription is invalid or expired") ) -// 预编译的Lua脚本 -var ( - // deductBalanceScript: 扣减余额缓存,key不存在则忽略 - deductBalanceScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current == false then - return 0 - end - local newVal = tonumber(current) - tonumber(ARGV[1]) - redis.call('SET', KEYS[1], newVal) - redis.call('EXPIRE', KEYS[1], ARGV[2]) - return 1 - `) - - // updateSubUsageScript: 更新订阅用量缓存,key不存在则忽略 - updateSubUsageScript = redis.NewScript(` - local exists = redis.call('EXISTS', KEYS[1]) - if exists == 0 then - return 0 - end - local cost = tonumber(ARGV[1]) - redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost) - redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost) - redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost) - redis.call('EXPIRE', KEYS[1], ARGV[2]) - return 1 - `) -) - // subscriptionCacheData 订阅缓存数据结构(内部使用) type subscriptionCacheData struct { Status string @@ -80,15 +31,15 @@ type subscriptionCacheData struct { // BillingCacheService 计费缓存服务 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 type BillingCacheService struct { - rdb *redis.Client + cache ports.BillingCache userRepo ports.UserRepository subRepo ports.UserSubscriptionRepository } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService { +func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService { return &BillingCacheService{ - rdb: rdb, + cache: cache, userRepo: userRepo, subRepo: subRepo, } @@ -100,24 +51,19 @@ func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, su // GetUserBalance 获取用户余额(优先从缓存读取) func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) { - if s.rdb == nil { + if s.cache == nil { // Redis不可用,直接查询数据库 return s.getUserBalanceFromDB(ctx, userID) } - key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) - // 尝试从缓存读取 - val, err := s.rdb.Get(ctx, key).Result() + balance, err := s.cache.GetUserBalance(ctx, userID) if err == nil { - balance, parseErr := strconv.ParseFloat(val, 64) - if parseErr == nil { - return balance, nil - } + return balance, nil } - // 缓存未命中或解析错误,从数据库读取 - balance, err := s.getUserBalanceFromDB(ctx, userID) + // 缓存未命中,从数据库读取 + balance, err = s.getUserBalanceFromDB(ctx, userID) if err != nil { return 0, err } @@ -143,39 +89,28 @@ func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID i // setBalanceCache 设置余额缓存 func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) { - if s.rdb == nil { + if s.cache == nil { return } - key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) - if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil { + if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil { log.Printf("Warning: set balance cache failed for user %d: %v", userID, err) } } // DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存) func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error { - if s.rdb == nil { + if s.cache == nil { return nil } - - key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) - - // 使用预编译的Lua脚本原子性扣减,如果key不存在则忽略 - _, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() - if err != nil && !errors.Is(err, redis.Nil) { - log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) - } - return nil + return s.cache.DeductUserBalance(ctx, userID, amount) } // InvalidateUserBalance 失效用户余额缓存 func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error { - if s.rdb == nil { + if s.cache == nil { return nil } - - key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) - if err := s.rdb.Del(ctx, key).Err(); err != nil { + if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil { log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err) return err } @@ -188,19 +123,14 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID // GetSubscriptionStatus 获取订阅状态(优先从缓存读取) func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) { - if s.rdb == nil { + if s.cache == nil { return s.getSubscriptionFromDB(ctx, userID, groupID) } - key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) - // 尝试从缓存读取 - result, err := s.rdb.HGetAll(ctx, key).Result() - if err == nil && len(result) > 0 { - data, parseErr := s.parseSubscriptionCache(result) - if parseErr == nil { - return data, nil - } + cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID) + if err == nil && cacheData != nil { + return s.convertFromPortsData(cacheData), nil } // 缓存未命中,从数据库读取 @@ -219,6 +149,28 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, return data, nil } +func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData { + return &subscriptionCacheData{ + Status: data.Status, + ExpiresAt: data.ExpiresAt, + DailyUsage: data.DailyUsage, + WeeklyUsage: data.WeeklyUsage, + MonthlyUsage: data.MonthlyUsage, + Version: data.Version, + } +} + +func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData { + return &ports.SubscriptionCacheData{ + Status: data.Status, + ExpiresAt: data.ExpiresAt, + DailyUsage: data.DailyUsage, + WeeklyUsage: data.WeeklyUsage, + MonthlyUsage: data.MonthlyUsage, + Version: data.Version, + } +} + // getSubscriptionFromDB 从数据库获取订阅数据 func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) { sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) @@ -236,90 +188,30 @@ func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, }, nil } -// parseSubscriptionCache 解析订阅缓存数据 -func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) { - result := &subscriptionCacheData{} - - result.Status = data[subFieldStatus] - if result.Status == "" { - return nil, errors.New("invalid cache: missing status") - } - - if expiresStr, ok := data[subFieldExpiresAt]; ok { - expiresAt, err := strconv.ParseInt(expiresStr, 10, 64) - if err == nil { - result.ExpiresAt = time.Unix(expiresAt, 0) - } - } - - if dailyStr, ok := data[subFieldDailyUsage]; ok { - result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64) - } - - if weeklyStr, ok := data[subFieldWeeklyUsage]; ok { - result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64) - } - - if monthlyStr, ok := data[subFieldMonthlyUsage]; ok { - result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64) - } - - if versionStr, ok := data[subFieldVersion]; ok { - result.Version, _ = strconv.ParseInt(versionStr, 10, 64) - } - - return result, nil -} - // setSubscriptionCache 设置订阅缓存 func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) { - if s.rdb == nil || data == nil { + if s.cache == nil || data == nil { return } - - key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) - - fields := map[string]interface{}{ - subFieldStatus: data.Status, - subFieldExpiresAt: data.ExpiresAt.Unix(), - subFieldDailyUsage: data.DailyUsage, - subFieldWeeklyUsage: data.WeeklyUsage, - subFieldMonthlyUsage: data.MonthlyUsage, - subFieldVersion: data.Version, - } - - pipe := s.rdb.Pipeline() - pipe.HSet(ctx, key, fields) - pipe.Expire(ctx, key, billingCacheTTL) - if _, err := pipe.Exec(ctx); err != nil { + if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil { log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err) } } // UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存) func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error { - if s.rdb == nil { + if s.cache == nil { return nil } - - key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) - - // 使用预编译的Lua脚本原子性增加用量,如果key不存在则忽略 - _, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result() - if err != nil && !errors.Is(err, redis.Nil) { - log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) - } - return nil + return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD) } // InvalidateSubscription 失效指定订阅缓存 func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error { - if s.rdb == nil { + if s.cache == nil { return nil } - - key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) - if err := s.rdb.Del(ctx, key).Err(); err != nil { + if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil { log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err) return err } diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index d03a34f2..5854ea67 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -2,22 +2,13 @@ package service import ( "context" - "fmt" "log" "time" - "github.com/redis/go-redis/v9" + "sub2api/internal/service/ports" ) const ( - // Redis key prefixes - accountConcurrencyKey = "concurrency:account:" - userConcurrencyKey = "concurrency:user:" - userWaitCountKey = "concurrency:wait:" - - // TTL for concurrency keys (auto-release safety net) - concurrencyKeyTTL = 10 * time.Minute - // Wait polling interval waitPollInterval = 100 * time.Millisecond @@ -28,70 +19,14 @@ const ( defaultExtraWaitSlots = 20 ) -// Pre-compiled Lua scripts for better performance -var ( - // acquireScript: increment counter if below max, return 1 if successful - acquireScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current == false then - current = 0 - else - current = tonumber(current) - end - if current < tonumber(ARGV[1]) then - redis.call('INCR', KEYS[1]) - redis.call('EXPIRE', KEYS[1], ARGV[2]) - return 1 - end - return 0 - `) - - // releaseScript: decrement counter, but don't go below 0 - releaseScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current ~= false and tonumber(current) > 0 then - redis.call('DECR', KEYS[1]) - end - return 1 - `) - - // incrementWaitScript: increment wait counter if below max, return 1 if successful - incrementWaitScript = redis.NewScript(` - local waitKey = KEYS[1] - local maxWait = tonumber(ARGV[1]) - local ttl = tonumber(ARGV[2]) - local current = redis.call('GET', waitKey) - if current == false then - current = 0 - else - current = tonumber(current) - end - if current >= maxWait then - return 0 - end - redis.call('INCR', waitKey) - redis.call('EXPIRE', waitKey, ttl) - return 1 - `) - - // decrementWaitScript: decrement wait counter, but don't go below 0 - decrementWaitScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current ~= false and tonumber(current) > 0 then - redis.call('DECR', KEYS[1]) - end - return 1 - `) -) - // ConcurrencyService manages concurrent request limiting for accounts and users type ConcurrencyService struct { - rdb *redis.Client + cache ports.ConcurrencyCache } // NewConcurrencyService creates a new ConcurrencyService -func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService { - return &ConcurrencyService{rdb: rdb} +func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService { + return &ConcurrencyService{cache: cache} } // AcquireResult represents the result of acquiring a concurrency slot @@ -104,20 +39,6 @@ type AcquireResult struct { // If the account is at max concurrency, it waits until a slot is available or timeout. // Returns a release function that MUST be called when the request completes. func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { - key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID) - return s.acquireSlot(ctx, key, maxConcurrency) -} - -// AcquireUserSlot attempts to acquire a concurrency slot for a user. -// If the user is at max concurrency, it waits until a slot is available or timeout. -// Returns a release function that MUST be called when the request completes. -func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) { - key := fmt.Sprintf("%s%d", userConcurrencyKey, userID) - return s.acquireSlot(ctx, key, maxConcurrency) -} - -// acquireSlot is the core implementation for acquiring a concurrency slot -func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) { // If maxConcurrency is 0 or negative, no limit if maxConcurrency <= 0 { return &AcquireResult{ @@ -126,8 +47,7 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon }, nil } - // Try to acquire immediately - acquired, err := s.tryAcquire(ctx, key, maxConcurrency) + acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency) if err != nil { return nil, err } @@ -135,64 +55,56 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon if acquired { return &AcquireResult{ Acquired: true, - ReleaseFunc: s.makeReleaseFunc(key), + ReleaseFunc: func() { + bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil { + log.Printf("Warning: failed to release account slot for %d: %v", accountID, err) + } + }, }, nil } - // Not acquired, return with Acquired=false - // The caller (gateway handler) will handle waiting with ping support return &AcquireResult{ Acquired: false, ReleaseFunc: nil, }, nil } -// tryAcquire attempts to increment the counter if below max -// Uses pre-compiled Lua script for atomicity and performance -func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) { - result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int() +// AcquireUserSlot attempts to acquire a concurrency slot for a user. +// If the user is at max concurrency, it waits until a slot is available or timeout. +// Returns a release function that MUST be called when the request completes. +func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) { + // If maxConcurrency is 0 or negative, no limit + if maxConcurrency <= 0 { + return &AcquireResult{ + Acquired: true, + ReleaseFunc: func() {}, // no-op + }, nil + } + + acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency) if err != nil { - return false, fmt.Errorf("acquire slot failed: %w", err) + return nil, err } - return result == 1, nil -} -// makeReleaseFunc creates a function to release a concurrency slot -func (s *ConcurrencyService) makeReleaseFunc(key string) func() { - return func() { - // Use background context to ensure release even if original context is cancelled - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil { - // Log error but don't panic - TTL will eventually clean up - log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err) - } + if acquired { + return &AcquireResult{ + Acquired: true, + ReleaseFunc: func() { + bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil { + log.Printf("Warning: failed to release user slot for %d: %v", userID, err) + } + }, + }, nil } -} -// GetCurrentCount returns the current concurrency count for debugging/monitoring -func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) { - val, err := s.rdb.Get(ctx, key).Int() - if err == redis.Nil { - return 0, nil - } - if err != nil { - return 0, err - } - return val, nil -} - -// GetAccountCurrentCount returns current concurrency count for an account -func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) { - key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID) - return s.GetCurrentCount(ctx, key) -} - -// GetUserCurrentCount returns current concurrency count for a user -func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) { - key := fmt.Sprintf("%s%d", userConcurrencyKey, userID) - return s.GetCurrentCount(ctx, key) + return &AcquireResult{ + Acquired: false, + ReleaseFunc: nil, + }, nil } // ============================================ @@ -203,44 +115,36 @@ func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int // Returns true if successful, false if the wait queue is full. // maxWait should be user.Concurrency + defaultExtraWaitSlots func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { - if s.rdb == nil { + if s.cache == nil { // Redis not available, allow request return true, nil } - key := fmt.Sprintf("%s%d", userWaitCountKey, userID) - result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int() + result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait) if err != nil { // On error, allow the request to proceed (fail open) log.Printf("Warning: increment wait count failed for user %d: %v", userID, err) return true, nil } - return result == 1, nil + return result, nil } // DecrementWaitCount decrements the wait queue counter for a user. // Should be called when a request completes or exits the wait queue. func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) { - if s.rdb == nil { + if s.cache == nil { return } - key := fmt.Sprintf("%s%d", userWaitCountKey, userID) // Use background context to ensure decrement even if original context is cancelled bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil { + if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil { log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err) } } -// GetUserWaitCount returns current wait queue count for a user -func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) { - key := fmt.Sprintf("%s%d", userWaitCountKey, userID) - return s.GetCurrentCount(ctx, key) -} - // CalculateMaxWait calculates the maximum wait queue size for a user // maxWait = userConcurrency + defaultExtraWaitSlots func CalculateMaxWait(userConcurrency int) int { diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 01cd98b4..e9637846 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -4,7 +4,6 @@ import ( "context" "crypto/rand" "crypto/tls" - "encoding/json" "errors" "fmt" "math/big" @@ -13,8 +12,6 @@ import ( "sub2api/internal/model" "sub2api/internal/service/ports" "time" - - "github.com/redis/go-redis/v9" ) var ( @@ -25,19 +22,11 @@ var ( ) const ( - verifyCodeKeyPrefix = "email_verify:" verifyCodeTTL = 15 * time.Minute verifyCodeCooldown = 1 * time.Minute maxVerifyCodeAttempts = 5 ) -// verifyCodeData Redis 中存储的验证码数据 -type verifyCodeData struct { - Code string `json:"code"` - Attempts int `json:"attempts"` - CreatedAt time.Time `json:"created_at"` -} - // SmtpConfig SMTP配置 type SmtpConfig struct { Host string @@ -52,14 +41,14 @@ type SmtpConfig struct { // EmailService 邮件服务 type EmailService struct { settingRepo ports.SettingRepository - rdb *redis.Client + cache ports.EmailCache } // NewEmailService 创建邮件服务实例 -func NewEmailService(settingRepo ports.SettingRepository, rdb *redis.Client) *EmailService { +func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService { return &EmailService{ settingRepo: settingRepo, - rdb: rdb, + cache: cache, } } @@ -201,10 +190,8 @@ func (s *EmailService) GenerateVerifyCode() (string, error) { // SendVerifyCode 发送验证码邮件 func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error { - key := verifyCodeKeyPrefix + email - // 检查是否在冷却期内 - existing, err := s.getVerifyCodeData(ctx, key) + existing, err := s.cache.GetVerificationCode(ctx, email) if err == nil && existing != nil { if time.Since(existing.CreatedAt) < verifyCodeCooldown { return ErrVerifyCodeTooFrequent @@ -218,12 +205,12 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin } // 保存验证码到 Redis - data := &verifyCodeData{ + data := &ports.VerificationCodeData{ Code: code, Attempts: 0, CreatedAt: time.Now(), } - if err := s.setVerifyCodeData(ctx, key, data); err != nil { + if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { return fmt.Errorf("save verify code: %w", err) } @@ -241,9 +228,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin // VerifyCode 验证验证码 func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error { - key := verifyCodeKeyPrefix + email - - data, err := s.getVerifyCodeData(ctx, key) + data, err := s.cache.GetVerificationCode(ctx, email) if err != nil || data == nil { return ErrInvalidVerifyCode } @@ -256,7 +241,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error // 验证码不匹配 if data.Code != code { data.Attempts++ - _ = s.setVerifyCodeData(ctx, key, data) + _ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL) if data.Attempts >= maxVerifyCodeAttempts { return ErrVerifyCodeMaxAttempts } @@ -264,32 +249,10 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error } // 验证成功,删除验证码 - s.rdb.Del(ctx, key) + _ = s.cache.DeleteVerificationCode(ctx, email) return nil } -// getVerifyCodeData 从 Redis 获取验证码数据 -func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) { - val, err := s.rdb.Get(ctx, key).Result() - if err != nil { - return nil, err - } - var data verifyCodeData - if err := json.Unmarshal([]byte(val), &data); err != nil { - return nil, err - } - return &data, nil -} - -// setVerifyCodeData 保存验证码数据到 Redis -func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error { - val, err := json.Marshal(data) - if err != nil { - return err - } - return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err() -} - // buildVerifyCodeEmailBody 构建验证码邮件HTML内容 func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string { return fmt.Sprintf(` diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index aa490ea1..aab17cca 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -24,13 +24,11 @@ import ( "sub2api/internal/service/ports" "github.com/gin-gonic/gin" - "github.com/redis/go-redis/v9" ) const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" - stickySessionPrefix = "sticky_session:" stickySessionTTL = time.Hour // 粘性会话TTL tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token ) @@ -82,7 +80,7 @@ type GatewayService struct { usageLogRepo ports.UsageLogRepository userRepo ports.UserRepository userSubRepo ports.UserSubscriptionRepository - rdb *redis.Client + cache ports.GatewayCache cfg *config.Config oauthService *OAuthService billingService *BillingService @@ -98,7 +96,7 @@ func NewGatewayService( usageLogRepo ports.UsageLogRepository, userRepo ports.UserRepository, userSubRepo ports.UserSubscriptionRepository, - rdb *redis.Client, + cache ports.GatewayCache, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, @@ -124,7 +122,7 @@ func NewGatewayService( usageLogRepo: usageLogRepo, userRepo: userRepo, userSubRepo: userSubRepo, - rdb: rdb, + cache: cache, cfg: cfg, oauthService: oauthService, billingService: billingService, @@ -290,14 +288,14 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { // 1. 查询粘性会话 if sessionHash != "" { - accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64() + accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) if err == nil && accountID > 0 { account, err := s.accountRepo.GetByID(ctx, accountID) // 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中 // 同时检查模型支持 if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { // 续期粘性会话 - s.rdb.Expire(ctx, stickySessionPrefix+sessionHash, stickySessionTTL) + s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) return account, nil } } @@ -347,7 +345,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // 4. 建立粘性绑定 if sessionHash != "" { - s.rdb.Set(ctx, stickySessionPrefix+sessionHash, selected.ID, stickySessionTTL) + s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL) } return selected, nil @@ -526,7 +524,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // OAuth账号:应用统一指纹 - var fingerprint *Fingerprint + var fingerprint *ports.Fingerprint if account.IsOAuth() && s.identityService != nil { // 1. 获取或创建指纹(包含随机生成的ClientID) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 022579d3..4bc9da6f 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -11,15 +11,10 @@ import ( "net/http" "regexp" "strconv" + "sub2api/internal/service/ports" "time" - - "github.com/redis/go-redis/v9" ) -const ( - // Redis key prefix - identityFingerprintKey = "identity:fingerprint:" -) // 预编译正则表达式(避免每次调用重新编译) var ( @@ -29,20 +24,8 @@ var ( userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) ) -// Fingerprint 存储的指纹数据结构 -type Fingerprint struct { - ClientID string `json:"client_id"` // 64位hex客户端ID(首次随机生成) - UserAgent string `json:"user_agent"` // User-Agent - StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang - StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version - StainlessOS string `json:"x_stainless_os"` // x-stainless-os - StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch - StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime - StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version -} - // 默认指纹值(当客户端未提供时使用) -var defaultFingerprint = Fingerprint{ +var defaultFingerprint = ports.Fingerprint{ UserAgent: "claude-cli/2.0.62 (external, cli)", StainlessLang: "js", StainlessPackageVersion: "0.52.0", @@ -54,39 +37,31 @@ var defaultFingerprint = Fingerprint{ // IdentityService 管理OAuth账号的请求身份指纹 type IdentityService struct { - rdb *redis.Client + cache ports.IdentityCache } // NewIdentityService 创建新的IdentityService -func NewIdentityService(rdb *redis.Client) *IdentityService { - return &IdentityService{rdb: rdb} +func NewIdentityService(cache ports.IdentityCache) *IdentityService { + return &IdentityService{cache: cache} } // GetOrCreateFingerprint 获取或创建账号的指纹 // 如果缓存存在,检测user-agent版本,新版本则更新 // 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存 -func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) { - key := identityFingerprintKey + strconv.FormatInt(accountID, 10) - - // 尝试从Redis获取缓存的指纹 - data, err := s.rdb.Get(ctx, key).Bytes() - if err == nil && len(data) > 0 { - // 缓存存在,解析指纹 - var cached Fingerprint - if err := json.Unmarshal(data, &cached); err == nil { - // 检查客户端的user-agent是否是更新版本 - clientUA := headers.Get("User-Agent") - if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) { - // 更新user-agent - cached.UserAgent = clientUA - // 保存更新后的指纹 - if newData, err := json.Marshal(cached); err == nil { - s.rdb.Set(ctx, key, newData, 0) // 永不过期 - } - log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA) - } - return &cached, nil +func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) { + // 尝试从缓存获取指纹 + cached, err := s.cache.GetFingerprint(ctx, accountID) + if err == nil && cached != nil { + // 检查客户端的user-agent是否是更新版本 + clientUA := headers.Get("User-Agent") + if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) { + // 更新user-agent + cached.UserAgent = clientUA + // 保存更新后的指纹 + _ = s.cache.SetFingerprint(ctx, accountID, cached) + log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA) } + return cached, nil } // 缓存不存在或解析失败,创建新指纹 @@ -95,11 +70,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID // 生成随机ClientID fp.ClientID = generateClientID() - // 保存到Redis(永不过期) - if data, err := json.Marshal(fp); err == nil { - if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil { - log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err) - } + // 保存到缓存(永不过期) + if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil { + log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err) } log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID) @@ -107,8 +80,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID } // createFingerprintFromHeaders 从请求头创建指纹 -func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint { - fp := &Fingerprint{} +func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint { + fp := &ports.Fingerprint{} // 获取User-Agent if ua := headers.Get("User-Agent"); ua != "" { @@ -137,7 +110,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string { } // ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头) -func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { +func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) { if fp == nil { return } diff --git a/backend/internal/service/ports/api_key_cache.go b/backend/internal/service/ports/api_key_cache.go new file mode 100644 index 00000000..0b9efb24 --- /dev/null +++ b/backend/internal/service/ports/api_key_cache.go @@ -0,0 +1,16 @@ +package ports + +import ( + "context" + "time" +) + +// ApiKeyCache defines cache operations for API key service +type ApiKeyCache interface { + GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) + IncrementCreateAttemptCount(ctx context.Context, userID int64) error + DeleteCreateAttemptCount(ctx context.Context, userID int64) error + + IncrementDailyUsage(ctx context.Context, apiKey string) error + SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error +} diff --git a/backend/internal/service/ports/billing_cache.go b/backend/internal/service/ports/billing_cache.go new file mode 100644 index 00000000..b357006a --- /dev/null +++ b/backend/internal/service/ports/billing_cache.go @@ -0,0 +1,31 @@ +package ports + +import ( + "context" + "time" +) + +// SubscriptionCacheData represents cached subscription data +type SubscriptionCacheData struct { + Status string + ExpiresAt time.Time + DailyUsage float64 + WeeklyUsage float64 + MonthlyUsage float64 + Version int64 +} + +// BillingCache defines cache operations for billing service +type BillingCache interface { + // Balance operations + GetUserBalance(ctx context.Context, userID int64) (float64, error) + SetUserBalance(ctx context.Context, userID int64, balance float64) error + DeductUserBalance(ctx context.Context, userID int64, amount float64) error + InvalidateUserBalance(ctx context.Context, userID int64) error + + // Subscription operations + GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) + SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error + UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error + InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error +} diff --git a/backend/internal/service/ports/concurrency_cache.go b/backend/internal/service/ports/concurrency_cache.go new file mode 100644 index 00000000..313737b7 --- /dev/null +++ b/backend/internal/service/ports/concurrency_cache.go @@ -0,0 +1,19 @@ +package ports + +import "context" + +// ConcurrencyCache defines cache operations for concurrency service +type ConcurrencyCache interface { + // Slot management + AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) + ReleaseAccountSlot(ctx context.Context, accountID int64) error + GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) + + AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) + ReleaseUserSlot(ctx context.Context, userID int64) error + GetUserConcurrency(ctx context.Context, userID int64) (int, error) + + // Wait queue + IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) + DecrementWaitCount(ctx context.Context, userID int64) error +} diff --git a/backend/internal/service/ports/email_cache.go b/backend/internal/service/ports/email_cache.go new file mode 100644 index 00000000..a48a3761 --- /dev/null +++ b/backend/internal/service/ports/email_cache.go @@ -0,0 +1,20 @@ +package ports + +import ( + "context" + "time" +) + +// VerificationCodeData represents verification code data +type VerificationCodeData struct { + Code string + Attempts int + CreatedAt time.Time +} + +// EmailCache defines cache operations for email service +type EmailCache interface { + GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) + SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error + DeleteVerificationCode(ctx context.Context, email string) error +} diff --git a/backend/internal/service/ports/gateway_cache.go b/backend/internal/service/ports/gateway_cache.go new file mode 100644 index 00000000..9df3aa40 --- /dev/null +++ b/backend/internal/service/ports/gateway_cache.go @@ -0,0 +1,13 @@ +package ports + +import ( + "context" + "time" +) + +// GatewayCache defines cache operations for gateway service +type GatewayCache interface { + GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) + SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error + RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error +} diff --git a/backend/internal/service/ports/identity_cache.go b/backend/internal/service/ports/identity_cache.go new file mode 100644 index 00000000..a8fbc611 --- /dev/null +++ b/backend/internal/service/ports/identity_cache.go @@ -0,0 +1,21 @@ +package ports + +import "context" + +// Fingerprint represents account fingerprint data +type Fingerprint struct { + ClientID string + UserAgent string + StainlessLang string + StainlessPackageVersion string + StainlessOS string + StainlessArch string + StainlessRuntime string + StainlessRuntimeVersion string +} + +// IdentityCache defines cache operations for identity service +type IdentityCache interface { + GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error) + SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error +} diff --git a/backend/internal/service/ports/redeem_cache.go b/backend/internal/service/ports/redeem_cache.go new file mode 100644 index 00000000..a90ad1de --- /dev/null +++ b/backend/internal/service/ports/redeem_cache.go @@ -0,0 +1,15 @@ +package ports + +import ( + "context" + "time" +) + +// RedeemCache defines cache operations for redeem service +type RedeemCache interface { + GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) + IncrementRedeemAttemptCount(ctx context.Context, userID int64) error + + AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) + ReleaseRedeemLock(ctx context.Context, code string) error +} diff --git a/backend/internal/service/ports/update_cache.go b/backend/internal/service/ports/update_cache.go new file mode 100644 index 00000000..125bbc62 --- /dev/null +++ b/backend/internal/service/ports/update_cache.go @@ -0,0 +1,12 @@ +package ports + +import ( + "context" + "time" +) + +// UpdateCache defines cache operations for update service +type UpdateCache interface { + GetUpdateInfo(ctx context.Context) (string, error) + SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error +} diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 72137f1b..de4c7249 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -26,11 +26,9 @@ var ( ) const ( - redeemRateLimitKeyPrefix = "redeem:rate_limit:" - redeemLockKeyPrefix = "redeem:lock:" - redeemMaxErrorsPerHour = 20 - redeemRateLimitDuration = time.Hour - redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁 + redeemMaxErrorsPerHour = 20 + redeemRateLimitDuration = time.Hour + redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁 ) // GenerateCodesRequest 生成兑换码请求 @@ -53,7 +51,7 @@ type RedeemService struct { redeemRepo ports.RedeemCodeRepository userRepo ports.UserRepository subscriptionService *SubscriptionService - rdb *redis.Client + cache ports.RedeemCache billingCacheService *BillingCacheService } @@ -62,14 +60,14 @@ func NewRedeemService( redeemRepo ports.RedeemCodeRepository, userRepo ports.UserRepository, subscriptionService *SubscriptionService, - rdb *redis.Client, + cache ports.RedeemCache, billingCacheService *BillingCacheService, ) *RedeemService { return &RedeemService{ redeemRepo: redeemRepo, userRepo: userRepo, subscriptionService: subscriptionService, - rdb: rdb, + cache: cache, billingCacheService: billingCacheService, } } @@ -140,13 +138,11 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ // checkRedeemRateLimit 检查用户兑换错误次数是否超限 func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error { - if s.rdb == nil { + if s.cache == nil { return nil } - key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) - - count, err := s.rdb.Get(ctx, key).Int() + count, err := s.cache.GetRedeemAttemptCount(ctx, userID) if err != nil && !errors.Is(err, redis.Nil) { // Redis 出错时不阻止用户操作 return nil @@ -161,27 +157,21 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) // incrementRedeemErrorCount 增加用户兑换错误计数 func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) { - if s.rdb == nil { + if s.cache == nil { return } - key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) - - pipe := s.rdb.Pipeline() - pipe.Incr(ctx, key) - pipe.Expire(ctx, key, redeemRateLimitDuration) - _, _ = pipe.Exec(ctx) + _ = s.cache.IncrementRedeemAttemptCount(ctx, userID) } // acquireRedeemLock 尝试获取兑换码的分布式锁 // 返回 true 表示获取成功,false 表示锁已被占用 func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool { - if s.rdb == nil { + if s.cache == nil { return true // 无 Redis 时降级为不加锁 } - key := redeemLockKeyPrefix + code - ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result() + ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration) if err != nil { // Redis 出错时不阻止操作,依赖数据库层面的状态检查 return true @@ -191,12 +181,11 @@ func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool // releaseRedeemLock 释放兑换码的分布式锁 func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) { - if s.rdb == nil { + if s.cache == nil { return } - key := redeemLockKeyPrefix + code - s.rdb.Del(ctx, key) + _ = s.cache.ReleaseRedeemLock(ctx, code) } // Redeem 使用兑换码 diff --git a/backend/internal/service/update_service.go b/backend/internal/service/update_service.go index 0cadff47..c65799cc 100644 --- a/backend/internal/service/update_service.go +++ b/backend/internal/service/update_service.go @@ -18,7 +18,7 @@ import ( "strings" "time" - "github.com/redis/go-redis/v9" + "sub2api/internal/service/ports" ) const ( @@ -36,15 +36,15 @@ const ( // UpdateService handles software updates type UpdateService struct { - rdb *redis.Client + cache ports.UpdateCache currentVersion string buildType string // "source" for manual builds, "release" for CI builds } // NewUpdateService creates a new UpdateService -func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService { +func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService { return &UpdateService{ - rdb: rdb, + cache: cache, currentVersion: version, buildType: buildType, } @@ -533,7 +533,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error { } func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) { - data, err := s.rdb.Get(ctx, updateCacheKey).Result() + data, err := s.cache.GetUpdateInfo(ctx) if err != nil { return nil, err } @@ -573,7 +573,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) { } data, _ := json.Marshal(cacheData) - s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second) + s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second) } // compareVersions compares two semantic versions