Merge PR #8: refactor(backend): 添加 service 缓存端口
This commit is contained in:
@@ -41,7 +41,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
settingRepository := repository.NewSettingRepository(db)
|
||||
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||
client := infrastructure.ProvideRedis(configConfig)
|
||||
emailService := service.NewEmailService(settingRepository, client)
|
||||
emailCache := repository.NewEmailCache(client)
|
||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||
turnstileService := service.NewTurnstileService(settingService)
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
@@ -51,15 +52,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||
groupRepository := repository.NewGroupRepository(db)
|
||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig)
|
||||
apiKeyCache := repository.NewApiKeyCache(client)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
||||
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
|
||||
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository)
|
||||
billingCache := repository.NewBillingCache(client)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService)
|
||||
redeemCache := repository.NewRedeemCache(client)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
accountRepository := repository.NewAccountRepository(db)
|
||||
@@ -81,14 +85,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
gatewayCache := repository.NewGatewayCache(client)
|
||||
pricingService, err := service.ProvidePricingService(configConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(client)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
|
||||
concurrencyService := service.NewConcurrencyService(client)
|
||||
identityCache := repository.NewIdentityCache(client)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/sysutil"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -19,8 +20,9 @@ type SystemHandler struct {
|
||||
|
||||
// NewSystemHandler creates a new SystemHandler
|
||||
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
|
||||
updateCache := repository.NewUpdateCache(rdb)
|
||||
return &SystemHandler{
|
||||
updateSvc: service.NewUpdateService(rdb, version, buildType),
|
||||
updateSvc: service.NewUpdateService(updateCache, version, buildType),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
51
backend/internal/repository/api_key_cache.go
Normal file
51
backend/internal/repository/api_key_cache.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return c.rdb.Incr(ctx, apiKey).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||
}
|
||||
174
backend/internal/repository/billing_cache.go
Normal file
174
backend/internal/repository/billing_cache.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
var (
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type billingCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
|
||||
return &billingCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return c.parseSubscriptionCache(result)
|
||||
}
|
||||
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
|
||||
result := &ports.SubscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
132
backend/internal/repository/concurrency_cache.go
Normal file
132
backend/internal/repository/concurrency_cache.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
accountConcurrencyKeyPrefix = "concurrency:account:"
|
||||
userConcurrencyKeyPrefix = "concurrency:user:"
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
concurrencyTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
|
||||
return &concurrencyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
48
backend/internal/repository/email_cache.go
Normal file
48
backend/internal/repository/email_cache.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
|
||||
return &emailCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data ports.VerificationCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
35
backend/internal/repository/gateway_cache.go
Normal file
35
backend/internal/repository/gateway_cache.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const stickySessionPrefix = "sticky_session:"
|
||||
|
||||
type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
|
||||
return &gatewayCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
return c.rdb.Get(ctx, key).Int64()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
47
backend/internal/repository/identity_cache.go
Normal file
47
backend/internal/repository/identity_cache.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
fingerprintKeyPrefix = "fingerprint:"
|
||||
fingerprintTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
|
||||
return &identityCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fp ports.Fingerprint
|
||||
if err := json.Unmarshal([]byte(val), &fp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fp, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := json.Marshal(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||
}
|
||||
49
backend/internal/repository/redeem_cache.go
Normal file
49
backend/internal/repository/redeem_cache.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
|
||||
redeemLockKeyPrefix = "redeem:lock:"
|
||||
redeemRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type redeemCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
|
||||
return &redeemCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
|
||||
key := redeemLockKeyPrefix + code
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
|
||||
key := redeemLockKeyPrefix + code
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
28
backend/internal/repository/update_cache.go
Normal file
28
backend/internal/repository/update_cache.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const updateCacheKey = "update:latest"
|
||||
|
||||
type updateCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
|
||||
return &updateCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
|
||||
return c.rdb.Get(ctx, updateCacheKey).Result()
|
||||
}
|
||||
|
||||
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
|
||||
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
|
||||
}
|
||||
@@ -19,6 +19,16 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserSubscriptionRepository,
|
||||
wire.Struct(new(Repositories), "*"),
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
NewBillingCache,
|
||||
NewApiKeyCache,
|
||||
NewConcurrencyCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
NewRedeemCache,
|
||||
NewUpdateCache,
|
||||
|
||||
// Bind concrete repositories to service port interfaces
|
||||
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
||||
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
|
||||
|
||||
@@ -27,9 +27,7 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
apiKeyRateLimitDuration = time.Hour
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
)
|
||||
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
@@ -52,7 +50,7 @@ type ApiKeyService struct {
|
||||
userRepo ports.UserRepository
|
||||
groupRepo ports.GroupRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
rdb *redis.Client
|
||||
cache ports.ApiKeyCache
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
@@ -62,7 +60,7 @@ func NewApiKeyService(
|
||||
userRepo ports.UserRepository,
|
||||
groupRepo ports.GroupRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
rdb *redis.Client,
|
||||
cache ports.ApiKeyCache,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
@@ -70,7 +68,7 @@ func NewApiKeyService(
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
@@ -113,13 +111,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
|
||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
count, err := s.rdb.Get(ctx, key).Int()
|
||||
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
@@ -134,16 +130,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
|
||||
|
||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
|
||||
}
|
||||
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
@@ -273,7 +264,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
|
||||
}
|
||||
|
||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
||||
if s.rdb != nil {
|
||||
if s.cache != nil {
|
||||
// 这里可以序列化并缓存API Key
|
||||
_ = cacheKey // 使用变量避免未使用错误
|
||||
}
|
||||
@@ -326,9 +317,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
if req.Status != nil {
|
||||
apiKey.Status = *req.Status
|
||||
// 如果状态改变,清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -355,9 +345,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
}
|
||||
|
||||
// 清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||
@@ -400,13 +389,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.rdb != nil {
|
||||
if s.cache != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
|
||||
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
|
||||
return fmt.Errorf("increment usage: %w", err)
|
||||
}
|
||||
// 设置24小时过期
|
||||
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
|
||||
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,30 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 缓存Key前缀和TTL
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// 订阅缓存Hash字段
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
@@ -38,35 +18,6 @@ var (
|
||||
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
|
||||
)
|
||||
|
||||
// 预编译的Lua脚本
|
||||
var (
|
||||
// deductBalanceScript: 扣减余额缓存,key不存在则忽略
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
// updateSubUsageScript: 更新订阅用量缓存,key不存在则忽略
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
type subscriptionCacheData struct {
|
||||
Status string
|
||||
@@ -80,15 +31,15 @@ type subscriptionCacheData struct {
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.BillingCache
|
||||
userRepo ports.UserRepository
|
||||
subRepo ports.UserSubscriptionRepository
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
|
||||
func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
}
|
||||
@@ -100,24 +51,19 @@ func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, su
|
||||
|
||||
// GetUserBalance 获取用户余额(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
// Redis不可用,直接查询数据库
|
||||
return s.getUserBalanceFromDB(ctx, userID)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
// 尝试从缓存读取
|
||||
val, err := s.rdb.Get(ctx, key).Result()
|
||||
balance, err := s.cache.GetUserBalance(ctx, userID)
|
||||
if err == nil {
|
||||
balance, parseErr := strconv.ParseFloat(val, 64)
|
||||
if parseErr == nil {
|
||||
return balance, nil
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// 缓存未命中或解析错误,从数据库读取
|
||||
balance, err := s.getUserBalanceFromDB(ctx, userID)
|
||||
// 缓存未命中,从数据库读取
|
||||
balance, err = s.getUserBalanceFromDB(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -143,39 +89,28 @@ func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID i
|
||||
|
||||
// setBalanceCache 设置余额缓存
|
||||
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
|
||||
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
|
||||
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
|
||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
// 使用预编译的Lua脚本原子性扣减,如果key不存在则忽略
|
||||
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
||||
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
|
||||
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
@@ -188,19 +123,14 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
|
||||
|
||||
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
// 尝试从缓存读取
|
||||
result, err := s.rdb.HGetAll(ctx, key).Result()
|
||||
if err == nil && len(result) > 0 {
|
||||
data, parseErr := s.parseSubscriptionCache(result)
|
||||
if parseErr == nil {
|
||||
return data, nil
|
||||
}
|
||||
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
if err == nil && cacheData != nil {
|
||||
return s.convertFromPortsData(cacheData), nil
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库读取
|
||||
@@ -219,6 +149,28 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData {
|
||||
return &subscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData {
|
||||
return &ports.SubscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
// getSubscriptionFromDB 从数据库获取订阅数据
|
||||
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
@@ -236,90 +188,30 @@ func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseSubscriptionCache 解析订阅缓存数据
|
||||
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
|
||||
result := &subscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// setSubscriptionCache 设置订阅缓存
|
||||
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
|
||||
if s.rdb == nil || data == nil {
|
||||
if s.cache == nil || data == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
|
||||
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
|
||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
// 使用预编译的Lua脚本原子性增加用量,如果key不存在则忽略
|
||||
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
||||
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
|
||||
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,22 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
const (
|
||||
// Redis key prefixes
|
||||
accountConcurrencyKey = "concurrency:account:"
|
||||
userConcurrencyKey = "concurrency:user:"
|
||||
userWaitCountKey = "concurrency:wait:"
|
||||
|
||||
// TTL for concurrency keys (auto-release safety net)
|
||||
concurrencyKeyTTL = 10 * time.Minute
|
||||
|
||||
// Wait polling interval
|
||||
waitPollInterval = 100 * time.Millisecond
|
||||
|
||||
@@ -28,70 +19,14 @@ const (
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
|
||||
// Pre-compiled Lua scripts for better performance
|
||||
var (
|
||||
// acquireScript: increment counter if below max, return 1 if successful
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
// releaseScript: decrement counter, but don't go below 0
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// incrementWaitScript: increment wait counter if below max, return 1 if successful
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript: decrement wait counter, but don't go below 0
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
type ConcurrencyService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.ConcurrencyCache
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
|
||||
return &ConcurrencyService{rdb: rdb}
|
||||
func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
||||
return &ConcurrencyService{cache: cache}
|
||||
}
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
@@ -104,20 +39,6 @@ type AcquireResult struct {
|
||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// acquireSlot is the core implementation for acquiring a concurrency slot
|
||||
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
@@ -126,8 +47,7 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Try to acquire immediately
|
||||
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
|
||||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -135,64 +55,56 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: s.makeReleaseFunc(key),
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
|
||||
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Not acquired, return with Acquired=false
|
||||
// The caller (gateway handler) will handle waiting with ping support
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tryAcquire attempts to increment the counter if below max
|
||||
// Uses pre-compiled Lua script for atomicity and performance
|
||||
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
|
||||
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {}, // no-op
|
||||
}, nil
|
||||
}
|
||||
|
||||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("acquire slot failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// makeReleaseFunc creates a function to release a concurrency slot
|
||||
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
|
||||
return func() {
|
||||
// Use background context to ensure release even if original context is cancelled
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
|
||||
// Log error but don't panic - TTL will eventually clean up
|
||||
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
|
||||
}
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetCurrentCount returns the current concurrency count for debugging/monitoring
|
||||
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
|
||||
val, err := s.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetAccountCurrentCount returns current concurrency count for an account
|
||||
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// GetUserCurrentCount returns current concurrency count for a user
|
||||
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
@@ -203,44 +115,36 @@ func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int
|
||||
// Returns true if successful, false if the wait queue is full.
|
||||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
// Redis not available, allow request
|
||||
return true, nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if err != nil {
|
||||
// On error, allow the request to proceed (fail open)
|
||||
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result == 1, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait queue counter for a user.
|
||||
// Should be called when a request completes or exits the wait queue.
|
||||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
// Use background context to ensure decrement even if original context is cancelled
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
|
||||
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserWaitCount returns current wait queue count for a user
|
||||
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||
func CalculateMaxWait(userConcurrency int) int {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -13,8 +12,6 @@ import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -25,19 +22,11 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
verifyCodeKeyPrefix = "email_verify:"
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
// verifyCodeData Redis 中存储的验证码数据
|
||||
type verifyCodeData struct {
|
||||
Code string `json:"code"`
|
||||
Attempts int `json:"attempts"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SmtpConfig SMTP配置
|
||||
type SmtpConfig struct {
|
||||
Host string
|
||||
@@ -52,14 +41,14 @@ type SmtpConfig struct {
|
||||
// EmailService 邮件服务
|
||||
type EmailService struct {
|
||||
settingRepo ports.SettingRepository
|
||||
rdb *redis.Client
|
||||
cache ports.EmailCache
|
||||
}
|
||||
|
||||
// NewEmailService 创建邮件服务实例
|
||||
func NewEmailService(settingRepo ports.SettingRepository, rdb *redis.Client) *EmailService {
|
||||
func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService {
|
||||
return &EmailService{
|
||||
settingRepo: settingRepo,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,10 +190,8 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
|
||||
|
||||
// SendVerifyCode 发送验证码邮件
|
||||
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
|
||||
// 检查是否在冷却期内
|
||||
existing, err := s.getVerifyCodeData(ctx, key)
|
||||
existing, err := s.cache.GetVerificationCode(ctx, email)
|
||||
if err == nil && existing != nil {
|
||||
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
|
||||
return ErrVerifyCodeTooFrequent
|
||||
@@ -218,12 +205,12 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
||||
}
|
||||
|
||||
// 保存验证码到 Redis
|
||||
data := &verifyCodeData{
|
||||
data := &ports.VerificationCodeData{
|
||||
Code: code,
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := s.setVerifyCodeData(ctx, key, data); err != nil {
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
return fmt.Errorf("save verify code: %w", err)
|
||||
}
|
||||
|
||||
@@ -241,9 +228,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
||||
|
||||
// VerifyCode 验证验证码
|
||||
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
|
||||
data, err := s.getVerifyCodeData(ctx, key)
|
||||
data, err := s.cache.GetVerificationCode(ctx, email)
|
||||
if err != nil || data == nil {
|
||||
return ErrInvalidVerifyCode
|
||||
}
|
||||
@@ -256,7 +241,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
data.Attempts++
|
||||
_ = s.setVerifyCodeData(ctx, key, data)
|
||||
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
@@ -264,32 +249,10 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
s.rdb.Del(ctx, key)
|
||||
_ = s.cache.DeleteVerificationCode(ctx, email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getVerifyCodeData 从 Redis 获取验证码数据
|
||||
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
|
||||
val, err := s.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data verifyCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// setVerifyCodeData 保存验证码数据到 Redis
|
||||
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
|
||||
}
|
||||
|
||||
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
|
||||
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||
return fmt.Sprintf(`
|
||||
|
||||
@@ -24,13 +24,11 @@ import (
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionPrefix = "sticky_session:"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
|
||||
)
|
||||
@@ -82,7 +80,7 @@ type GatewayService struct {
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
rdb *redis.Client
|
||||
cache ports.GatewayCache
|
||||
cfg *config.Config
|
||||
oauthService *OAuthService
|
||||
billingService *BillingService
|
||||
@@ -98,7 +96,7 @@ func NewGatewayService(
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userRepo ports.UserRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
rdb *redis.Client,
|
||||
cache ports.GatewayCache,
|
||||
cfg *config.Config,
|
||||
oauthService *OAuthService,
|
||||
billingService *BillingService,
|
||||
@@ -124,7 +122,7 @@ func NewGatewayService(
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
oauthService: oauthService,
|
||||
billingService: billingService,
|
||||
@@ -290,14 +288,14 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
|
||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 续期粘性会话
|
||||
s.rdb.Expire(ctx, stickySessionPrefix+sessionHash, stickySessionTTL)
|
||||
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -347,7 +345,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
s.rdb.Set(ctx, stickySessionPrefix+sessionHash, selected.ID, stickySessionTTL)
|
||||
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
@@ -526,7 +524,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
// OAuth账号:应用统一指纹
|
||||
var fingerprint *Fingerprint
|
||||
var fingerprint *ports.Fingerprint
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
|
||||
@@ -11,15 +11,10 @@ import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
// Redis key prefix
|
||||
identityFingerprintKey = "identity:fingerprint:"
|
||||
)
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
@@ -29,20 +24,8 @@ var (
|
||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||
)
|
||||
|
||||
// Fingerprint 存储的指纹数据结构
|
||||
type Fingerprint struct {
|
||||
ClientID string `json:"client_id"` // 64位hex客户端ID(首次随机生成)
|
||||
UserAgent string `json:"user_agent"` // User-Agent
|
||||
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
|
||||
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
|
||||
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
|
||||
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
|
||||
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
|
||||
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
|
||||
}
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
var defaultFingerprint = ports.Fingerprint{
|
||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.52.0",
|
||||
@@ -54,39 +37,31 @@ var defaultFingerprint = Fingerprint{
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
type IdentityService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.IdentityCache
|
||||
}
|
||||
|
||||
// NewIdentityService 创建新的IdentityService
|
||||
func NewIdentityService(rdb *redis.Client) *IdentityService {
|
||||
return &IdentityService{rdb: rdb}
|
||||
func NewIdentityService(cache ports.IdentityCache) *IdentityService {
|
||||
return &IdentityService{cache: cache}
|
||||
}
|
||||
|
||||
// GetOrCreateFingerprint 获取或创建账号的指纹
|
||||
// 如果缓存存在,检测user-agent版本,新版本则更新
|
||||
// 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
|
||||
key := identityFingerprintKey + strconv.FormatInt(accountID, 10)
|
||||
|
||||
// 尝试从Redis获取缓存的指纹
|
||||
data, err := s.rdb.Get(ctx, key).Bytes()
|
||||
if err == nil && len(data) > 0 {
|
||||
// 缓存存在,解析指纹
|
||||
var cached Fingerprint
|
||||
if err := json.Unmarshal(data, &cached); err == nil {
|
||||
// 检查客户端的user-agent是否是更新版本
|
||||
clientUA := headers.Get("User-Agent")
|
||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||
// 更新user-agent
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
if newData, err := json.Marshal(cached); err == nil {
|
||||
s.rdb.Set(ctx, key, newData, 0) // 永不过期
|
||||
}
|
||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
}
|
||||
return &cached, nil
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) {
|
||||
// 尝试从缓存获取指纹
|
||||
cached, err := s.cache.GetFingerprint(ctx, accountID)
|
||||
if err == nil && cached != nil {
|
||||
// 检查客户端的user-agent是否是更新版本
|
||||
clientUA := headers.Get("User-Agent")
|
||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||
// 更新user-agent
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
_ = s.cache.SetFingerprint(ctx, accountID, cached)
|
||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// 缓存不存在或解析失败,创建新指纹
|
||||
@@ -95,11 +70,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
// 生成随机ClientID
|
||||
fp.ClientID = generateClientID()
|
||||
|
||||
// 保存到Redis(永不过期)
|
||||
if data, err := json.Marshal(fp); err == nil {
|
||||
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil {
|
||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
// 保存到缓存(永不过期)
|
||||
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
|
||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
|
||||
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
|
||||
@@ -107,8 +80,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// createFingerprintFromHeaders 从请求头创建指纹
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
|
||||
fp := &Fingerprint{}
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint {
|
||||
fp := &ports.Fingerprint{}
|
||||
|
||||
// 获取User-Agent
|
||||
if ua := headers.Get("User-Agent"); ua != "" {
|
||||
@@ -137,7 +110,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
|
||||
}
|
||||
|
||||
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) {
|
||||
if fp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
16
backend/internal/service/ports/api_key_cache.go
Normal file
16
backend/internal/service/ports/api_key_cache.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ApiKeyCache defines cache operations for API key service
|
||||
type ApiKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
|
||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
}
|
||||
31
backend/internal/service/ports/billing_cache.go
Normal file
31
backend/internal/service/ports/billing_cache.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SubscriptionCacheData represents cached subscription data
|
||||
type SubscriptionCacheData struct {
|
||||
Status string
|
||||
ExpiresAt time.Time
|
||||
DailyUsage float64
|
||||
WeeklyUsage float64
|
||||
MonthlyUsage float64
|
||||
Version int64
|
||||
}
|
||||
|
||||
// BillingCache defines cache operations for billing service
|
||||
type BillingCache interface {
|
||||
// Balance operations
|
||||
GetUserBalance(ctx context.Context, userID int64) (float64, error)
|
||||
SetUserBalance(ctx context.Context, userID int64, balance float64) error
|
||||
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
|
||||
InvalidateUserBalance(ctx context.Context, userID int64) error
|
||||
|
||||
// Subscription operations
|
||||
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
|
||||
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||||
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||||
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||||
}
|
||||
19
backend/internal/service/ports/concurrency_cache.go
Normal file
19
backend/internal/service/ports/concurrency_cache.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package ports
|
||||
|
||||
import "context"
|
||||
|
||||
// ConcurrencyCache defines cache operations for concurrency service
|
||||
type ConcurrencyCache interface {
|
||||
// Slot management
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// Wait queue
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
}
|
||||
20
backend/internal/service/ports/email_cache.go
Normal file
20
backend/internal/service/ports/email_cache.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// VerificationCodeData represents verification code data
|
||||
type VerificationCodeData struct {
|
||||
Code string
|
||||
Attempts int
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// EmailCache defines cache operations for email service
|
||||
type EmailCache interface {
|
||||
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||
DeleteVerificationCode(ctx context.Context, email string) error
|
||||
}
|
||||
13
backend/internal/service/ports/gateway_cache.go
Normal file
13
backend/internal/service/ports/gateway_cache.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GatewayCache defines cache operations for gateway service
|
||||
type GatewayCache interface {
|
||||
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
|
||||
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||
}
|
||||
21
backend/internal/service/ports/identity_cache.go
Normal file
21
backend/internal/service/ports/identity_cache.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package ports
|
||||
|
||||
import "context"
|
||||
|
||||
// Fingerprint represents account fingerprint data
|
||||
type Fingerprint struct {
|
||||
ClientID string
|
||||
UserAgent string
|
||||
StainlessLang string
|
||||
StainlessPackageVersion string
|
||||
StainlessOS string
|
||||
StainlessArch string
|
||||
StainlessRuntime string
|
||||
StainlessRuntimeVersion string
|
||||
}
|
||||
|
||||
// IdentityCache defines cache operations for identity service
|
||||
type IdentityCache interface {
|
||||
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
|
||||
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
|
||||
}
|
||||
15
backend/internal/service/ports/redeem_cache.go
Normal file
15
backend/internal/service/ports/redeem_cache.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedeemCache defines cache operations for redeem service
|
||||
type RedeemCache interface {
|
||||
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementRedeemAttemptCount(ctx context.Context, userID int64) error
|
||||
|
||||
AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error)
|
||||
ReleaseRedeemLock(ctx context.Context, code string) error
|
||||
}
|
||||
12
backend/internal/service/ports/update_cache.go
Normal file
12
backend/internal/service/ports/update_cache.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UpdateCache defines cache operations for update service
|
||||
type UpdateCache interface {
|
||||
GetUpdateInfo(ctx context.Context) (string, error)
|
||||
SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error
|
||||
}
|
||||
@@ -26,11 +26,9 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
|
||||
redeemLockKeyPrefix = "redeem:lock:"
|
||||
redeemMaxErrorsPerHour = 20
|
||||
redeemRateLimitDuration = time.Hour
|
||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||
redeemMaxErrorsPerHour = 20
|
||||
redeemRateLimitDuration = time.Hour
|
||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||
)
|
||||
|
||||
// GenerateCodesRequest 生成兑换码请求
|
||||
@@ -53,7 +51,7 @@ type RedeemService struct {
|
||||
redeemRepo ports.RedeemCodeRepository
|
||||
userRepo ports.UserRepository
|
||||
subscriptionService *SubscriptionService
|
||||
rdb *redis.Client
|
||||
cache ports.RedeemCache
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
@@ -62,14 +60,14 @@ func NewRedeemService(
|
||||
redeemRepo ports.RedeemCodeRepository,
|
||||
userRepo ports.UserRepository,
|
||||
subscriptionService *SubscriptionService,
|
||||
rdb *redis.Client,
|
||||
cache ports.RedeemCache,
|
||||
billingCacheService *BillingCacheService,
|
||||
) *RedeemService {
|
||||
return &RedeemService{
|
||||
redeemRepo: redeemRepo,
|
||||
userRepo: userRepo,
|
||||
subscriptionService: subscriptionService,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
@@ -140,13 +138,11 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
|
||||
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
|
||||
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
count, err := s.rdb.Get(ctx, key).Int()
|
||||
count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
@@ -161,27 +157,21 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64)
|
||||
|
||||
// incrementRedeemErrorCount 增加用户兑换错误计数
|
||||
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
_ = s.cache.IncrementRedeemAttemptCount(ctx, userID)
|
||||
}
|
||||
|
||||
// acquireRedeemLock 尝试获取兑换码的分布式锁
|
||||
// 返回 true 表示获取成功,false 表示锁已被占用
|
||||
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return true // 无 Redis 时降级为不加锁
|
||||
}
|
||||
|
||||
key := redeemLockKeyPrefix + code
|
||||
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
|
||||
ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration)
|
||||
if err != nil {
|
||||
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
|
||||
return true
|
||||
@@ -191,12 +181,11 @@ func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool
|
||||
|
||||
// releaseRedeemLock 释放兑换码的分布式锁
|
||||
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := redeemLockKeyPrefix + code
|
||||
s.rdb.Del(ctx, key)
|
||||
_ = s.cache.ReleaseRedeemLock(ctx, code)
|
||||
}
|
||||
|
||||
// Redeem 使用兑换码
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -36,15 +36,15 @@ const (
|
||||
|
||||
// UpdateService handles software updates
|
||||
type UpdateService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.UpdateCache
|
||||
currentVersion string
|
||||
buildType string // "source" for manual builds, "release" for CI builds
|
||||
}
|
||||
|
||||
// NewUpdateService creates a new UpdateService
|
||||
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService {
|
||||
func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService {
|
||||
return &UpdateService{
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
currentVersion: version,
|
||||
buildType: buildType,
|
||||
}
|
||||
@@ -533,7 +533,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
}
|
||||
|
||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
|
||||
data, err := s.cache.GetUpdateInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -573,7 +573,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(cacheData)
|
||||
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second)
|
||||
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||
}
|
||||
|
||||
// compareVersions compares two semantic versions
|
||||
|
||||
Reference in New Issue
Block a user