Merge PR #8: refactor(backend): 添加 service 缓存端口
This commit is contained in:
@@ -41,7 +41,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
settingRepository := repository.NewSettingRepository(db)
|
settingRepository := repository.NewSettingRepository(db)
|
||||||
settingService := service.NewSettingService(settingRepository, configConfig)
|
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||||
client := infrastructure.ProvideRedis(configConfig)
|
client := infrastructure.ProvideRedis(configConfig)
|
||||||
emailService := service.NewEmailService(settingRepository, client)
|
emailCache := repository.NewEmailCache(client)
|
||||||
|
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||||
turnstileService := service.NewTurnstileService(settingService)
|
turnstileService := service.NewTurnstileService(settingService)
|
||||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||||
@@ -51,15 +52,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||||
groupRepository := repository.NewGroupRepository(db)
|
groupRepository := repository.NewGroupRepository(db)
|
||||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
|
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
|
||||||
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig)
|
apiKeyCache := repository.NewApiKeyCache(client)
|
||||||
|
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(db)
|
usageLogRepository := repository.NewUsageLogRepository(db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
|
||||||
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
|
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
|
||||||
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository)
|
billingCache := repository.NewBillingCache(client)
|
||||||
|
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
|
||||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService)
|
redeemCache := repository.NewRedeemCache(client)
|
||||||
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
accountRepository := repository.NewAccountRepository(db)
|
accountRepository := repository.NewAccountRepository(db)
|
||||||
@@ -81,14 +85,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||||
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||||
|
gatewayCache := repository.NewGatewayCache(client)
|
||||||
pricingService, err := service.ProvidePricingService(configConfig)
|
pricingService, err := service.ProvidePricingService(configConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityService := service.NewIdentityService(client)
|
identityCache := repository.NewIdentityCache(client)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
concurrencyService := service.NewConcurrencyService(client)
|
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
|
||||||
|
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||||
|
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"sub2api/internal/pkg/response"
|
"sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/pkg/sysutil"
|
"sub2api/internal/pkg/sysutil"
|
||||||
|
"sub2api/internal/repository"
|
||||||
"sub2api/internal/service"
|
"sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -19,8 +20,9 @@ type SystemHandler struct {
|
|||||||
|
|
||||||
// NewSystemHandler creates a new SystemHandler
|
// NewSystemHandler creates a new SystemHandler
|
||||||
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
|
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
|
||||||
|
updateCache := repository.NewUpdateCache(rdb)
|
||||||
return &SystemHandler{
|
return &SystemHandler{
|
||||||
updateSvc: service.NewUpdateService(rdb, version, buildType),
|
updateSvc: service.NewUpdateService(updateCache, version, buildType),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
51
backend/internal/repository/api_key_cache.go
Normal file
51
backend/internal/repository/api_key_cache.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||||
|
apiKeyRateLimitDuration = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
type apiKeyCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
|
||||||
|
return &apiKeyCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
|
return c.rdb.Get(ctx, key).Int()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
pipe.Incr(ctx, key)
|
||||||
|
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||||
|
return c.rdb.Incr(ctx, apiKey).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||||
|
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||||
|
}
|
||||||
174
backend/internal/repository/billing_cache.go
Normal file
174
backend/internal/repository/billing_cache.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
billingBalanceKeyPrefix = "billing:balance:"
|
||||||
|
billingSubKeyPrefix = "billing:sub:"
|
||||||
|
billingCacheTTL = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
subFieldStatus = "status"
|
||||||
|
subFieldExpiresAt = "expires_at"
|
||||||
|
subFieldDailyUsage = "daily_usage"
|
||||||
|
subFieldWeeklyUsage = "weekly_usage"
|
||||||
|
subFieldMonthlyUsage = "monthly_usage"
|
||||||
|
subFieldVersion = "version"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
deductBalanceScript = redis.NewScript(`
|
||||||
|
local current = redis.call('GET', KEYS[1])
|
||||||
|
if current == false then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||||
|
redis.call('SET', KEYS[1], newVal)
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
|
||||||
|
updateSubUsageScript = redis.NewScript(`
|
||||||
|
local exists = redis.call('EXISTS', KEYS[1])
|
||||||
|
if exists == 0 then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
local cost = tonumber(ARGV[1])
|
||||||
|
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||||
|
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||||
|
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
)
|
||||||
|
|
||||||
|
type billingCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
|
||||||
|
return &billingCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return strconv.ParseFloat(val, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||||
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
|
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
|
||||||
|
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
return c.parseSubscriptionCache(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
|
||||||
|
result := &ports.SubscriptionCacheData{}
|
||||||
|
|
||||||
|
result.Status = data[subFieldStatus]
|
||||||
|
if result.Status == "" {
|
||||||
|
return nil, errors.New("invalid cache: missing status")
|
||||||
|
}
|
||||||
|
|
||||||
|
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||||
|
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||||
|
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||||
|
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||||
|
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
if versionStr, ok := data[subFieldVersion]; ok {
|
||||||
|
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
|
||||||
|
if data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
|
||||||
|
fields := map[string]interface{}{
|
||||||
|
subFieldStatus: data.Status,
|
||||||
|
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||||
|
subFieldDailyUsage: data.DailyUsage,
|
||||||
|
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||||
|
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||||
|
subFieldVersion: data.Version,
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
pipe.HSet(ctx, key, fields)
|
||||||
|
pipe.Expire(ctx, key, billingCacheTTL)
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||||
|
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||||
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
|
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
132
backend/internal/repository/concurrency_cache.go
Normal file
132
backend/internal/repository/concurrency_cache.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
accountConcurrencyKeyPrefix = "concurrency:account:"
|
||||||
|
userConcurrencyKeyPrefix = "concurrency:user:"
|
||||||
|
waitQueueKeyPrefix = "concurrency:wait:"
|
||||||
|
concurrencyTTL = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
acquireScript = redis.NewScript(`
|
||||||
|
local current = redis.call('GET', KEYS[1])
|
||||||
|
if current == false then
|
||||||
|
current = 0
|
||||||
|
else
|
||||||
|
current = tonumber(current)
|
||||||
|
end
|
||||||
|
if current < tonumber(ARGV[1]) then
|
||||||
|
redis.call('INCR', KEYS[1])
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
|
return 1
|
||||||
|
end
|
||||||
|
return 0
|
||||||
|
`)
|
||||||
|
|
||||||
|
releaseScript = redis.NewScript(`
|
||||||
|
local current = redis.call('GET', KEYS[1])
|
||||||
|
if current ~= false and tonumber(current) > 0 then
|
||||||
|
redis.call('DECR', KEYS[1])
|
||||||
|
end
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
|
||||||
|
incrementWaitScript = redis.NewScript(`
|
||||||
|
local waitKey = KEYS[1]
|
||||||
|
local maxWait = tonumber(ARGV[1])
|
||||||
|
local ttl = tonumber(ARGV[2])
|
||||||
|
local current = redis.call('GET', waitKey)
|
||||||
|
if current == false then
|
||||||
|
current = 0
|
||||||
|
else
|
||||||
|
current = tonumber(current)
|
||||||
|
end
|
||||||
|
if current >= maxWait then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
redis.call('INCR', waitKey)
|
||||||
|
redis.call('EXPIRE', waitKey, ttl)
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
|
||||||
|
decrementWaitScript = redis.NewScript(`
|
||||||
|
local current = redis.call('GET', KEYS[1])
|
||||||
|
if current ~= false and tonumber(current) > 0 then
|
||||||
|
redis.call('DECR', KEYS[1])
|
||||||
|
end
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
)
|
||||||
|
|
||||||
|
type concurrencyCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
|
||||||
|
return &concurrencyCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||||
|
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return result == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||||
|
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||||
|
return c.rdb.Get(ctx, key).Int()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||||
|
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return result == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||||
|
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||||
|
return c.rdb.Get(ctx, key).Int()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
|
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return result == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
|
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||||
|
return err
|
||||||
|
}
|
||||||
48
backend/internal/repository/email_cache.go
Normal file
48
backend/internal/repository/email_cache.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const verifyCodeKeyPrefix = "verify_code:"
|
||||||
|
|
||||||
|
type emailCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
|
||||||
|
return &emailCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
|
||||||
|
key := verifyCodeKeyPrefix + email
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var data ports.VerificationCodeData
|
||||||
|
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
|
||||||
|
key := verifyCodeKeyPrefix + email
|
||||||
|
val, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||||
|
key := verifyCodeKeyPrefix + email
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
35
backend/internal/repository/gateway_cache.go
Normal file
35
backend/internal/repository/gateway_cache.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const stickySessionPrefix = "sticky_session:"
|
||||||
|
|
||||||
|
type gatewayCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
|
||||||
|
return &gatewayCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||||
|
key := stickySessionPrefix + sessionHash
|
||||||
|
return c.rdb.Get(ctx, key).Int64()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||||
|
key := stickySessionPrefix + sessionHash
|
||||||
|
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||||
|
key := stickySessionPrefix + sessionHash
|
||||||
|
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||||
|
}
|
||||||
47
backend/internal/repository/identity_cache.go
Normal file
47
backend/internal/repository/identity_cache.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
fingerprintKeyPrefix = "fingerprint:"
|
||||||
|
fingerprintTTL = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
type identityCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
|
||||||
|
return &identityCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var fp ports.Fingerprint
|
||||||
|
if err := json.Unmarshal([]byte(val), &fp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &fp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
|
||||||
|
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||||
|
val, err := json.Marshal(fp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||||
|
}
|
||||||
49
backend/internal/repository/redeem_cache.go
Normal file
49
backend/internal/repository/redeem_cache.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
|
||||||
|
redeemLockKeyPrefix = "redeem:lock:"
|
||||||
|
redeemRateLimitDuration = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
type redeemCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
|
||||||
|
return &redeemCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||||
|
return c.rdb.Get(ctx, key).Int()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
pipe.Incr(ctx, key)
|
||||||
|
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
|
||||||
|
key := redeemLockKeyPrefix + code
|
||||||
|
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
|
||||||
|
key := redeemLockKeyPrefix + code
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
28
backend/internal/repository/update_cache.go
Normal file
28
backend/internal/repository/update_cache.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const updateCacheKey = "update:latest"
|
||||||
|
|
||||||
|
type updateCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
|
||||||
|
return &updateCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
|
||||||
|
return c.rdb.Get(ctx, updateCacheKey).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
|
||||||
|
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
|
||||||
|
}
|
||||||
@@ -19,6 +19,16 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewUserSubscriptionRepository,
|
NewUserSubscriptionRepository,
|
||||||
wire.Struct(new(Repositories), "*"),
|
wire.Struct(new(Repositories), "*"),
|
||||||
|
|
||||||
|
// Cache implementations
|
||||||
|
NewGatewayCache,
|
||||||
|
NewBillingCache,
|
||||||
|
NewApiKeyCache,
|
||||||
|
NewConcurrencyCache,
|
||||||
|
NewEmailCache,
|
||||||
|
NewIdentityCache,
|
||||||
|
NewRedeemCache,
|
||||||
|
NewUpdateCache,
|
||||||
|
|
||||||
// Bind concrete repositories to service port interfaces
|
// Bind concrete repositories to service port interfaces
|
||||||
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
||||||
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
|
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
|
||||||
|
|||||||
@@ -27,9 +27,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
|
apiKeyMaxErrorsPerHour = 20
|
||||||
apiKeyMaxErrorsPerHour = 20
|
|
||||||
apiKeyRateLimitDuration = time.Hour
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateApiKeyRequest 创建API Key请求
|
// CreateApiKeyRequest 创建API Key请求
|
||||||
@@ -52,7 +50,7 @@ type ApiKeyService struct {
|
|||||||
userRepo ports.UserRepository
|
userRepo ports.UserRepository
|
||||||
groupRepo ports.GroupRepository
|
groupRepo ports.GroupRepository
|
||||||
userSubRepo ports.UserSubscriptionRepository
|
userSubRepo ports.UserSubscriptionRepository
|
||||||
rdb *redis.Client
|
cache ports.ApiKeyCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +60,7 @@ func NewApiKeyService(
|
|||||||
userRepo ports.UserRepository,
|
userRepo ports.UserRepository,
|
||||||
groupRepo ports.GroupRepository,
|
groupRepo ports.GroupRepository,
|
||||||
userSubRepo ports.UserSubscriptionRepository,
|
userSubRepo ports.UserSubscriptionRepository,
|
||||||
rdb *redis.Client,
|
cache ports.ApiKeyCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *ApiKeyService {
|
) *ApiKeyService {
|
||||||
return &ApiKeyService{
|
return &ApiKeyService{
|
||||||
@@ -70,7 +68,7 @@ func NewApiKeyService(
|
|||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
rdb: rdb,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -113,13 +111,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
|||||||
|
|
||||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
|
||||||
|
|
||||||
count, err := s.rdb.Get(ctx, key).Int()
|
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
// Redis 出错时不阻止用户操作
|
// Redis 出错时不阻止用户操作
|
||||||
return nil
|
return nil
|
||||||
@@ -134,16 +130,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
|
|||||||
|
|
||||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
|
||||||
|
|
||||||
pipe := s.rdb.Pipeline()
|
|
||||||
pipe.Incr(ctx, key)
|
|
||||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
|
||||||
_, _ = pipe.Exec(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||||
@@ -273,7 +264,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
// 缓存到Redis(可选,TTL设置为5分钟)
|
||||||
if s.rdb != nil {
|
if s.cache != nil {
|
||||||
// 这里可以序列化并缓存API Key
|
// 这里可以序列化并缓存API Key
|
||||||
_ = cacheKey // 使用变量避免未使用错误
|
_ = cacheKey // 使用变量避免未使用错误
|
||||||
}
|
}
|
||||||
@@ -326,9 +317,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
|||||||
if req.Status != nil {
|
if req.Status != nil {
|
||||||
apiKey.Status = *req.Status
|
apiKey.Status = *req.Status
|
||||||
// 如果状态改变,清除Redis缓存
|
// 如果状态改变,清除Redis缓存
|
||||||
if s.rdb != nil {
|
if s.cache != nil {
|
||||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||||
_ = s.rdb.Del(ctx, cacheKey)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +345,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 清除Redis缓存
|
// 清除Redis缓存
|
||||||
if s.rdb != nil {
|
if s.cache != nil {
|
||||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||||
_ = s.rdb.Del(ctx, cacheKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||||
@@ -400,13 +389,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
|
|||||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||||
// 使用Redis计数器
|
// 使用Redis计数器
|
||||||
if s.rdb != nil {
|
if s.cache != nil {
|
||||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||||
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
|
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
|
||||||
return fmt.Errorf("increment usage: %w", err)
|
return fmt.Errorf("increment usage: %w", err)
|
||||||
}
|
}
|
||||||
// 设置24小时过期
|
// 设置24小时过期
|
||||||
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
|
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,30 +5,10 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"sub2api/internal/model"
|
||||||
"sub2api/internal/service/ports"
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 缓存Key前缀和TTL
|
|
||||||
const (
|
|
||||||
billingBalanceKeyPrefix = "billing:balance:"
|
|
||||||
billingSubKeyPrefix = "billing:sub:"
|
|
||||||
billingCacheTTL = 5 * time.Minute
|
|
||||||
)
|
|
||||||
|
|
||||||
// 订阅缓存Hash字段
|
|
||||||
const (
|
|
||||||
subFieldStatus = "status"
|
|
||||||
subFieldExpiresAt = "expires_at"
|
|
||||||
subFieldDailyUsage = "daily_usage"
|
|
||||||
subFieldWeeklyUsage = "weekly_usage"
|
|
||||||
subFieldMonthlyUsage = "monthly_usage"
|
|
||||||
subFieldVersion = "version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// 错误定义
|
// 错误定义
|
||||||
@@ -38,35 +18,6 @@ var (
|
|||||||
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
|
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
|
||||||
)
|
)
|
||||||
|
|
||||||
// 预编译的Lua脚本
|
|
||||||
var (
|
|
||||||
// deductBalanceScript: 扣减余额缓存,key不存在则忽略
|
|
||||||
deductBalanceScript = redis.NewScript(`
|
|
||||||
local current = redis.call('GET', KEYS[1])
|
|
||||||
if current == false then
|
|
||||||
return 0
|
|
||||||
end
|
|
||||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
|
||||||
redis.call('SET', KEYS[1], newVal)
|
|
||||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
|
||||||
return 1
|
|
||||||
`)
|
|
||||||
|
|
||||||
// updateSubUsageScript: 更新订阅用量缓存,key不存在则忽略
|
|
||||||
updateSubUsageScript = redis.NewScript(`
|
|
||||||
local exists = redis.call('EXISTS', KEYS[1])
|
|
||||||
if exists == 0 then
|
|
||||||
return 0
|
|
||||||
end
|
|
||||||
local cost = tonumber(ARGV[1])
|
|
||||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
|
||||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
|
||||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
|
||||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
|
||||||
return 1
|
|
||||||
`)
|
|
||||||
)
|
|
||||||
|
|
||||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||||
type subscriptionCacheData struct {
|
type subscriptionCacheData struct {
|
||||||
Status string
|
Status string
|
||||||
@@ -80,15 +31,15 @@ type subscriptionCacheData struct {
|
|||||||
// BillingCacheService 计费缓存服务
|
// BillingCacheService 计费缓存服务
|
||||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||||
type BillingCacheService struct {
|
type BillingCacheService struct {
|
||||||
rdb *redis.Client
|
cache ports.BillingCache
|
||||||
userRepo ports.UserRepository
|
userRepo ports.UserRepository
|
||||||
subRepo ports.UserSubscriptionRepository
|
subRepo ports.UserSubscriptionRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBillingCacheService 创建计费缓存服务
|
// NewBillingCacheService 创建计费缓存服务
|
||||||
func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
|
func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
|
||||||
return &BillingCacheService{
|
return &BillingCacheService{
|
||||||
rdb: rdb,
|
cache: cache,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
subRepo: subRepo,
|
subRepo: subRepo,
|
||||||
}
|
}
|
||||||
@@ -100,24 +51,19 @@ func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, su
|
|||||||
|
|
||||||
// GetUserBalance 获取用户余额(优先从缓存读取)
|
// GetUserBalance 获取用户余额(优先从缓存读取)
|
||||||
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
// Redis不可用,直接查询数据库
|
// Redis不可用,直接查询数据库
|
||||||
return s.getUserBalanceFromDB(ctx, userID)
|
return s.getUserBalanceFromDB(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
|
||||||
|
|
||||||
// 尝试从缓存读取
|
// 尝试从缓存读取
|
||||||
val, err := s.rdb.Get(ctx, key).Result()
|
balance, err := s.cache.GetUserBalance(ctx, userID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
balance, parseErr := strconv.ParseFloat(val, 64)
|
return balance, nil
|
||||||
if parseErr == nil {
|
|
||||||
return balance, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 缓存未命中或解析错误,从数据库读取
|
// 缓存未命中,从数据库读取
|
||||||
balance, err := s.getUserBalanceFromDB(ctx, userID)
|
balance, err = s.getUserBalanceFromDB(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -143,39 +89,28 @@ func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID i
|
|||||||
|
|
||||||
// setBalanceCache 设置余额缓存
|
// setBalanceCache 设置余额缓存
|
||||||
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
|
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
|
||||||
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
|
|
||||||
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
|
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
|
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
|
||||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
|
||||||
|
|
||||||
// 使用预编译的Lua脚本原子性扣减,如果key不存在则忽略
|
|
||||||
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
|
||||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// InvalidateUserBalance 失效用户余额缓存
|
// InvalidateUserBalance 失效用户余额缓存
|
||||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
|
||||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
|
||||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
|
||||||
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -188,19 +123,14 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
|
|||||||
|
|
||||||
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
|
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
|
||||||
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return s.getSubscriptionFromDB(ctx, userID, groupID)
|
return s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
|
||||||
|
|
||||||
// 尝试从缓存读取
|
// 尝试从缓存读取
|
||||||
result, err := s.rdb.HGetAll(ctx, key).Result()
|
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||||
if err == nil && len(result) > 0 {
|
if err == nil && cacheData != nil {
|
||||||
data, parseErr := s.parseSubscriptionCache(result)
|
return s.convertFromPortsData(cacheData), nil
|
||||||
if parseErr == nil {
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 缓存未命中,从数据库读取
|
// 缓存未命中,从数据库读取
|
||||||
@@ -219,6 +149,28 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
|
|||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData {
|
||||||
|
return &subscriptionCacheData{
|
||||||
|
Status: data.Status,
|
||||||
|
ExpiresAt: data.ExpiresAt,
|
||||||
|
DailyUsage: data.DailyUsage,
|
||||||
|
WeeklyUsage: data.WeeklyUsage,
|
||||||
|
MonthlyUsage: data.MonthlyUsage,
|
||||||
|
Version: data.Version,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData {
|
||||||
|
return &ports.SubscriptionCacheData{
|
||||||
|
Status: data.Status,
|
||||||
|
ExpiresAt: data.ExpiresAt,
|
||||||
|
DailyUsage: data.DailyUsage,
|
||||||
|
WeeklyUsage: data.WeeklyUsage,
|
||||||
|
MonthlyUsage: data.MonthlyUsage,
|
||||||
|
Version: data.Version,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getSubscriptionFromDB 从数据库获取订阅数据
|
// getSubscriptionFromDB 从数据库获取订阅数据
|
||||||
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||||
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||||
@@ -236,90 +188,30 @@ func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseSubscriptionCache 解析订阅缓存数据
|
|
||||||
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
|
|
||||||
result := &subscriptionCacheData{}
|
|
||||||
|
|
||||||
result.Status = data[subFieldStatus]
|
|
||||||
if result.Status == "" {
|
|
||||||
return nil, errors.New("invalid cache: missing status")
|
|
||||||
}
|
|
||||||
|
|
||||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
|
||||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
|
||||||
if err == nil {
|
|
||||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
|
||||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
|
||||||
}
|
|
||||||
|
|
||||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
|
||||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
|
||||||
}
|
|
||||||
|
|
||||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
|
||||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
|
||||||
}
|
|
||||||
|
|
||||||
if versionStr, ok := data[subFieldVersion]; ok {
|
|
||||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setSubscriptionCache 设置订阅缓存
|
// setSubscriptionCache 设置订阅缓存
|
||||||
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
|
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
|
||||||
if s.rdb == nil || data == nil {
|
if s.cache == nil || data == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
|
||||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
|
||||||
|
|
||||||
fields := map[string]interface{}{
|
|
||||||
subFieldStatus: data.Status,
|
|
||||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
|
||||||
subFieldDailyUsage: data.DailyUsage,
|
|
||||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
|
||||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
|
||||||
subFieldVersion: data.Version,
|
|
||||||
}
|
|
||||||
|
|
||||||
pipe := s.rdb.Pipeline()
|
|
||||||
pipe.HSet(ctx, key, fields)
|
|
||||||
pipe.Expire(ctx, key, billingCacheTTL)
|
|
||||||
if _, err := pipe.Exec(ctx); err != nil {
|
|
||||||
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
|
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
|
||||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
|
||||||
|
|
||||||
// 使用预编译的Lua脚本原子性增加用量,如果key不存在则忽略
|
|
||||||
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
|
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
|
||||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// InvalidateSubscription 失效指定订阅缓存
|
// InvalidateSubscription 失效指定订阅缓存
|
||||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
|
||||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
|
||||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
|
||||||
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,22 +2,13 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"sub2api/internal/service/ports"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Redis key prefixes
|
|
||||||
accountConcurrencyKey = "concurrency:account:"
|
|
||||||
userConcurrencyKey = "concurrency:user:"
|
|
||||||
userWaitCountKey = "concurrency:wait:"
|
|
||||||
|
|
||||||
// TTL for concurrency keys (auto-release safety net)
|
|
||||||
concurrencyKeyTTL = 10 * time.Minute
|
|
||||||
|
|
||||||
// Wait polling interval
|
// Wait polling interval
|
||||||
waitPollInterval = 100 * time.Millisecond
|
waitPollInterval = 100 * time.Millisecond
|
||||||
|
|
||||||
@@ -28,70 +19,14 @@ const (
|
|||||||
defaultExtraWaitSlots = 20
|
defaultExtraWaitSlots = 20
|
||||||
)
|
)
|
||||||
|
|
||||||
// Pre-compiled Lua scripts for better performance
|
|
||||||
var (
|
|
||||||
// acquireScript: increment counter if below max, return 1 if successful
|
|
||||||
acquireScript = redis.NewScript(`
|
|
||||||
local current = redis.call('GET', KEYS[1])
|
|
||||||
if current == false then
|
|
||||||
current = 0
|
|
||||||
else
|
|
||||||
current = tonumber(current)
|
|
||||||
end
|
|
||||||
if current < tonumber(ARGV[1]) then
|
|
||||||
redis.call('INCR', KEYS[1])
|
|
||||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
|
||||||
return 1
|
|
||||||
end
|
|
||||||
return 0
|
|
||||||
`)
|
|
||||||
|
|
||||||
// releaseScript: decrement counter, but don't go below 0
|
|
||||||
releaseScript = redis.NewScript(`
|
|
||||||
local current = redis.call('GET', KEYS[1])
|
|
||||||
if current ~= false and tonumber(current) > 0 then
|
|
||||||
redis.call('DECR', KEYS[1])
|
|
||||||
end
|
|
||||||
return 1
|
|
||||||
`)
|
|
||||||
|
|
||||||
// incrementWaitScript: increment wait counter if below max, return 1 if successful
|
|
||||||
incrementWaitScript = redis.NewScript(`
|
|
||||||
local waitKey = KEYS[1]
|
|
||||||
local maxWait = tonumber(ARGV[1])
|
|
||||||
local ttl = tonumber(ARGV[2])
|
|
||||||
local current = redis.call('GET', waitKey)
|
|
||||||
if current == false then
|
|
||||||
current = 0
|
|
||||||
else
|
|
||||||
current = tonumber(current)
|
|
||||||
end
|
|
||||||
if current >= maxWait then
|
|
||||||
return 0
|
|
||||||
end
|
|
||||||
redis.call('INCR', waitKey)
|
|
||||||
redis.call('EXPIRE', waitKey, ttl)
|
|
||||||
return 1
|
|
||||||
`)
|
|
||||||
|
|
||||||
// decrementWaitScript: decrement wait counter, but don't go below 0
|
|
||||||
decrementWaitScript = redis.NewScript(`
|
|
||||||
local current = redis.call('GET', KEYS[1])
|
|
||||||
if current ~= false and tonumber(current) > 0 then
|
|
||||||
redis.call('DECR', KEYS[1])
|
|
||||||
end
|
|
||||||
return 1
|
|
||||||
`)
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||||
type ConcurrencyService struct {
|
type ConcurrencyService struct {
|
||||||
rdb *redis.Client
|
cache ports.ConcurrencyCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConcurrencyService creates a new ConcurrencyService
|
// NewConcurrencyService creates a new ConcurrencyService
|
||||||
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
|
func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
||||||
return &ConcurrencyService{rdb: rdb}
|
return &ConcurrencyService{cache: cache}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcquireResult represents the result of acquiring a concurrency slot
|
// AcquireResult represents the result of acquiring a concurrency slot
|
||||||
@@ -104,20 +39,6 @@ type AcquireResult struct {
|
|||||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||||
// Returns a release function that MUST be called when the request completes.
|
// Returns a release function that MUST be called when the request completes.
|
||||||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
|
||||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
|
||||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
|
||||||
// Returns a release function that MUST be called when the request completes.
|
|
||||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
|
||||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
|
||||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
|
||||||
}
|
|
||||||
|
|
||||||
// acquireSlot is the core implementation for acquiring a concurrency slot
|
|
||||||
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
|
|
||||||
// If maxConcurrency is 0 or negative, no limit
|
// If maxConcurrency is 0 or negative, no limit
|
||||||
if maxConcurrency <= 0 {
|
if maxConcurrency <= 0 {
|
||||||
return &AcquireResult{
|
return &AcquireResult{
|
||||||
@@ -126,8 +47,7 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to acquire immediately
|
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -135,64 +55,56 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
|||||||
if acquired {
|
if acquired {
|
||||||
return &AcquireResult{
|
return &AcquireResult{
|
||||||
Acquired: true,
|
Acquired: true,
|
||||||
ReleaseFunc: s.makeReleaseFunc(key),
|
ReleaseFunc: func() {
|
||||||
|
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
|
||||||
|
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
|
||||||
|
}
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not acquired, return with Acquired=false
|
|
||||||
// The caller (gateway handler) will handle waiting with ping support
|
|
||||||
return &AcquireResult{
|
return &AcquireResult{
|
||||||
Acquired: false,
|
Acquired: false,
|
||||||
ReleaseFunc: nil,
|
ReleaseFunc: nil,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// tryAcquire attempts to increment the counter if below max
|
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||||
// Uses pre-compiled Lua script for atomicity and performance
|
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||||
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
|
// Returns a release function that MUST be called when the request completes.
|
||||||
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
|
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||||
|
// If maxConcurrency is 0 or negative, no limit
|
||||||
|
if maxConcurrency <= 0 {
|
||||||
|
return &AcquireResult{
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: func() {}, // no-op
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("acquire slot failed: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
return result == 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeReleaseFunc creates a function to release a concurrency slot
|
if acquired {
|
||||||
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
|
return &AcquireResult{
|
||||||
return func() {
|
Acquired: true,
|
||||||
// Use background context to ensure release even if original context is cancelled
|
ReleaseFunc: func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
|
||||||
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
|
log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
|
||||||
// Log error but don't panic - TTL will eventually clean up
|
}
|
||||||
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
|
},
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// GetCurrentCount returns the current concurrency count for debugging/monitoring
|
return &AcquireResult{
|
||||||
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
|
Acquired: false,
|
||||||
val, err := s.rdb.Get(ctx, key).Int()
|
ReleaseFunc: nil,
|
||||||
if err == redis.Nil {
|
}, nil
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return val, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountCurrentCount returns current concurrency count for an account
|
|
||||||
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
|
|
||||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
|
||||||
return s.GetCurrentCount(ctx, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserCurrentCount returns current concurrency count for a user
|
|
||||||
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
|
|
||||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
|
||||||
return s.GetCurrentCount(ctx, key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
@@ -203,44 +115,36 @@ func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int
|
|||||||
// Returns true if successful, false if the wait queue is full.
|
// Returns true if successful, false if the wait queue is full.
|
||||||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||||||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
// Redis not available, allow request
|
// Redis not available, allow request
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||||||
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// On error, allow the request to proceed (fail open)
|
// On error, allow the request to proceed (fail open)
|
||||||
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
|
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
return result == 1, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecrementWaitCount decrements the wait queue counter for a user.
|
// DecrementWaitCount decrements the wait queue counter for a user.
|
||||||
// Should be called when a request completes or exits the wait queue.
|
// Should be called when a request completes or exits the wait queue.
|
||||||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
|
||||||
// Use background context to ensure decrement even if original context is cancelled
|
// Use background context to ensure decrement even if original context is cancelled
|
||||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
|
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||||||
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
|
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserWaitCount returns current wait queue count for a user
|
|
||||||
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
|
|
||||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
|
||||||
return s.GetCurrentCount(ctx, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||||
func CalculateMaxWait(userConcurrency int) int {
|
func CalculateMaxWait(userConcurrency int) int {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
@@ -13,8 +12,6 @@ import (
|
|||||||
"sub2api/internal/model"
|
"sub2api/internal/model"
|
||||||
"sub2api/internal/service/ports"
|
"sub2api/internal/service/ports"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -25,19 +22,11 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
verifyCodeKeyPrefix = "email_verify:"
|
|
||||||
verifyCodeTTL = 15 * time.Minute
|
verifyCodeTTL = 15 * time.Minute
|
||||||
verifyCodeCooldown = 1 * time.Minute
|
verifyCodeCooldown = 1 * time.Minute
|
||||||
maxVerifyCodeAttempts = 5
|
maxVerifyCodeAttempts = 5
|
||||||
)
|
)
|
||||||
|
|
||||||
// verifyCodeData Redis 中存储的验证码数据
|
|
||||||
type verifyCodeData struct {
|
|
||||||
Code string `json:"code"`
|
|
||||||
Attempts int `json:"attempts"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SmtpConfig SMTP配置
|
// SmtpConfig SMTP配置
|
||||||
type SmtpConfig struct {
|
type SmtpConfig struct {
|
||||||
Host string
|
Host string
|
||||||
@@ -52,14 +41,14 @@ type SmtpConfig struct {
|
|||||||
// EmailService 邮件服务
|
// EmailService 邮件服务
|
||||||
type EmailService struct {
|
type EmailService struct {
|
||||||
settingRepo ports.SettingRepository
|
settingRepo ports.SettingRepository
|
||||||
rdb *redis.Client
|
cache ports.EmailCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewEmailService 创建邮件服务实例
|
// NewEmailService 创建邮件服务实例
|
||||||
func NewEmailService(settingRepo ports.SettingRepository, rdb *redis.Client) *EmailService {
|
func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService {
|
||||||
return &EmailService{
|
return &EmailService{
|
||||||
settingRepo: settingRepo,
|
settingRepo: settingRepo,
|
||||||
rdb: rdb,
|
cache: cache,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,10 +190,8 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
|
|||||||
|
|
||||||
// SendVerifyCode 发送验证码邮件
|
// SendVerifyCode 发送验证码邮件
|
||||||
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
|
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
|
||||||
key := verifyCodeKeyPrefix + email
|
|
||||||
|
|
||||||
// 检查是否在冷却期内
|
// 检查是否在冷却期内
|
||||||
existing, err := s.getVerifyCodeData(ctx, key)
|
existing, err := s.cache.GetVerificationCode(ctx, email)
|
||||||
if err == nil && existing != nil {
|
if err == nil && existing != nil {
|
||||||
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
|
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
|
||||||
return ErrVerifyCodeTooFrequent
|
return ErrVerifyCodeTooFrequent
|
||||||
@@ -218,12 +205,12 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 保存验证码到 Redis
|
// 保存验证码到 Redis
|
||||||
data := &verifyCodeData{
|
data := &ports.VerificationCodeData{
|
||||||
Code: code,
|
Code: code,
|
||||||
Attempts: 0,
|
Attempts: 0,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if err := s.setVerifyCodeData(ctx, key, data); err != nil {
|
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||||
return fmt.Errorf("save verify code: %w", err)
|
return fmt.Errorf("save verify code: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,9 +228,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
|||||||
|
|
||||||
// VerifyCode 验证验证码
|
// VerifyCode 验证验证码
|
||||||
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
|
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
|
||||||
key := verifyCodeKeyPrefix + email
|
data, err := s.cache.GetVerificationCode(ctx, email)
|
||||||
|
|
||||||
data, err := s.getVerifyCodeData(ctx, key)
|
|
||||||
if err != nil || data == nil {
|
if err != nil || data == nil {
|
||||||
return ErrInvalidVerifyCode
|
return ErrInvalidVerifyCode
|
||||||
}
|
}
|
||||||
@@ -256,7 +241,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
|||||||
// 验证码不匹配
|
// 验证码不匹配
|
||||||
if data.Code != code {
|
if data.Code != code {
|
||||||
data.Attempts++
|
data.Attempts++
|
||||||
_ = s.setVerifyCodeData(ctx, key, data)
|
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
|
||||||
if data.Attempts >= maxVerifyCodeAttempts {
|
if data.Attempts >= maxVerifyCodeAttempts {
|
||||||
return ErrVerifyCodeMaxAttempts
|
return ErrVerifyCodeMaxAttempts
|
||||||
}
|
}
|
||||||
@@ -264,32 +249,10 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证成功,删除验证码
|
// 验证成功,删除验证码
|
||||||
s.rdb.Del(ctx, key)
|
_ = s.cache.DeleteVerificationCode(ctx, email)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getVerifyCodeData 从 Redis 获取验证码数据
|
|
||||||
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
|
|
||||||
val, err := s.rdb.Get(ctx, key).Result()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var data verifyCodeData
|
|
||||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setVerifyCodeData 保存验证码数据到 Redis
|
|
||||||
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
|
|
||||||
val, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
|
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
|
||||||
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||||
return fmt.Sprintf(`
|
return fmt.Sprintf(`
|
||||||
|
|||||||
@@ -24,13 +24,11 @@ import (
|
|||||||
"sub2api/internal/service/ports"
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||||
stickySessionPrefix = "sticky_session:"
|
|
||||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||||
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
|
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
|
||||||
)
|
)
|
||||||
@@ -82,7 +80,7 @@ type GatewayService struct {
|
|||||||
usageLogRepo ports.UsageLogRepository
|
usageLogRepo ports.UsageLogRepository
|
||||||
userRepo ports.UserRepository
|
userRepo ports.UserRepository
|
||||||
userSubRepo ports.UserSubscriptionRepository
|
userSubRepo ports.UserSubscriptionRepository
|
||||||
rdb *redis.Client
|
cache ports.GatewayCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
oauthService *OAuthService
|
oauthService *OAuthService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
@@ -98,7 +96,7 @@ func NewGatewayService(
|
|||||||
usageLogRepo ports.UsageLogRepository,
|
usageLogRepo ports.UsageLogRepository,
|
||||||
userRepo ports.UserRepository,
|
userRepo ports.UserRepository,
|
||||||
userSubRepo ports.UserSubscriptionRepository,
|
userSubRepo ports.UserSubscriptionRepository,
|
||||||
rdb *redis.Client,
|
cache ports.GatewayCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
oauthService *OAuthService,
|
oauthService *OAuthService,
|
||||||
billingService *BillingService,
|
billingService *BillingService,
|
||||||
@@ -124,7 +122,7 @@ func NewGatewayService(
|
|||||||
usageLogRepo: usageLogRepo,
|
usageLogRepo: usageLogRepo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
rdb: rdb,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
oauthService: oauthService,
|
oauthService: oauthService,
|
||||||
billingService: billingService,
|
billingService: billingService,
|
||||||
@@ -290,14 +288,14 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
|
|||||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
|
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||||
// 同时检查模型支持
|
// 同时检查模型支持
|
||||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
// 续期粘性会话
|
// 续期粘性会话
|
||||||
s.rdb.Expire(ctx, stickySessionPrefix+sessionHash, stickySessionTTL)
|
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -347,7 +345,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
|||||||
|
|
||||||
// 4. 建立粘性绑定
|
// 4. 建立粘性绑定
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
s.rdb.Set(ctx, stickySessionPrefix+sessionHash, selected.ID, stickySessionTTL)
|
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return selected, nil
|
return selected, nil
|
||||||
@@ -526,7 +524,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OAuth账号:应用统一指纹
|
// OAuth账号:应用统一指纹
|
||||||
var fingerprint *Fingerprint
|
var fingerprint *ports.Fingerprint
|
||||||
if account.IsOAuth() && s.identityService != nil {
|
if account.IsOAuth() && s.identityService != nil {
|
||||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||||
|
|||||||
@@ -11,15 +11,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// Redis key prefix
|
|
||||||
identityFingerprintKey = "identity:fingerprint:"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 预编译正则表达式(避免每次调用重新编译)
|
// 预编译正则表达式(避免每次调用重新编译)
|
||||||
var (
|
var (
|
||||||
@@ -29,20 +24,8 @@ var (
|
|||||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Fingerprint 存储的指纹数据结构
|
|
||||||
type Fingerprint struct {
|
|
||||||
ClientID string `json:"client_id"` // 64位hex客户端ID(首次随机生成)
|
|
||||||
UserAgent string `json:"user_agent"` // User-Agent
|
|
||||||
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
|
|
||||||
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
|
|
||||||
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
|
|
||||||
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
|
|
||||||
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
|
|
||||||
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
|
|
||||||
}
|
|
||||||
|
|
||||||
// 默认指纹值(当客户端未提供时使用)
|
// 默认指纹值(当客户端未提供时使用)
|
||||||
var defaultFingerprint = Fingerprint{
|
var defaultFingerprint = ports.Fingerprint{
|
||||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||||
StainlessLang: "js",
|
StainlessLang: "js",
|
||||||
StainlessPackageVersion: "0.52.0",
|
StainlessPackageVersion: "0.52.0",
|
||||||
@@ -54,39 +37,31 @@ var defaultFingerprint = Fingerprint{
|
|||||||
|
|
||||||
// IdentityService 管理OAuth账号的请求身份指纹
|
// IdentityService 管理OAuth账号的请求身份指纹
|
||||||
type IdentityService struct {
|
type IdentityService struct {
|
||||||
rdb *redis.Client
|
cache ports.IdentityCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIdentityService 创建新的IdentityService
|
// NewIdentityService 创建新的IdentityService
|
||||||
func NewIdentityService(rdb *redis.Client) *IdentityService {
|
func NewIdentityService(cache ports.IdentityCache) *IdentityService {
|
||||||
return &IdentityService{rdb: rdb}
|
return &IdentityService{cache: cache}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrCreateFingerprint 获取或创建账号的指纹
|
// GetOrCreateFingerprint 获取或创建账号的指纹
|
||||||
// 如果缓存存在,检测user-agent版本,新版本则更新
|
// 如果缓存存在,检测user-agent版本,新版本则更新
|
||||||
// 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存
|
// 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存
|
||||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
|
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) {
|
||||||
key := identityFingerprintKey + strconv.FormatInt(accountID, 10)
|
// 尝试从缓存获取指纹
|
||||||
|
cached, err := s.cache.GetFingerprint(ctx, accountID)
|
||||||
// 尝试从Redis获取缓存的指纹
|
if err == nil && cached != nil {
|
||||||
data, err := s.rdb.Get(ctx, key).Bytes()
|
// 检查客户端的user-agent是否是更新版本
|
||||||
if err == nil && len(data) > 0 {
|
clientUA := headers.Get("User-Agent")
|
||||||
// 缓存存在,解析指纹
|
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||||
var cached Fingerprint
|
// 更新user-agent
|
||||||
if err := json.Unmarshal(data, &cached); err == nil {
|
cached.UserAgent = clientUA
|
||||||
// 检查客户端的user-agent是否是更新版本
|
// 保存更新后的指纹
|
||||||
clientUA := headers.Get("User-Agent")
|
_ = s.cache.SetFingerprint(ctx, accountID, cached)
|
||||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||||
// 更新user-agent
|
|
||||||
cached.UserAgent = clientUA
|
|
||||||
// 保存更新后的指纹
|
|
||||||
if newData, err := json.Marshal(cached); err == nil {
|
|
||||||
s.rdb.Set(ctx, key, newData, 0) // 永不过期
|
|
||||||
}
|
|
||||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
|
||||||
}
|
|
||||||
return &cached, nil
|
|
||||||
}
|
}
|
||||||
|
return cached, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 缓存不存在或解析失败,创建新指纹
|
// 缓存不存在或解析失败,创建新指纹
|
||||||
@@ -95,11 +70,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
|||||||
// 生成随机ClientID
|
// 生成随机ClientID
|
||||||
fp.ClientID = generateClientID()
|
fp.ClientID = generateClientID()
|
||||||
|
|
||||||
// 保存到Redis(永不过期)
|
// 保存到缓存(永不过期)
|
||||||
if data, err := json.Marshal(fp); err == nil {
|
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
|
||||||
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil {
|
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
|
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
|
||||||
@@ -107,8 +80,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createFingerprintFromHeaders 从请求头创建指纹
|
// createFingerprintFromHeaders 从请求头创建指纹
|
||||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
|
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint {
|
||||||
fp := &Fingerprint{}
|
fp := &ports.Fingerprint{}
|
||||||
|
|
||||||
// 获取User-Agent
|
// 获取User-Agent
|
||||||
if ua := headers.Get("User-Agent"); ua != "" {
|
if ua := headers.Get("User-Agent"); ua != "" {
|
||||||
@@ -137,7 +110,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
|
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
|
||||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) {
|
||||||
if fp == nil {
|
if fp == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
16
backend/internal/service/ports/api_key_cache.go
Normal file
16
backend/internal/service/ports/api_key_cache.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package ports
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApiKeyCache defines cache operations for API key service
|
||||||
|
type ApiKeyCache interface {
|
||||||
|
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||||
|
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||||
|
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||||
|
|
||||||
|
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||||
|
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||||
|
}
|
||||||
31
backend/internal/service/ports/billing_cache.go
Normal file
31
backend/internal/service/ports/billing_cache.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package ports
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SubscriptionCacheData represents cached subscription data
|
||||||
|
type SubscriptionCacheData struct {
|
||||||
|
Status string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
DailyUsage float64
|
||||||
|
WeeklyUsage float64
|
||||||
|
MonthlyUsage float64
|
||||||
|
Version int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// BillingCache defines cache operations for billing service
|
||||||
|
type BillingCache interface {
|
||||||
|
// Balance operations
|
||||||
|
GetUserBalance(ctx context.Context, userID int64) (float64, error)
|
||||||
|
SetUserBalance(ctx context.Context, userID int64, balance float64) error
|
||||||
|
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
|
||||||
|
InvalidateUserBalance(ctx context.Context, userID int64) error
|
||||||
|
|
||||||
|
// Subscription operations
|
||||||
|
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
|
||||||
|
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||||||
|
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||||||
|
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||||||
|
}
|
||||||
19
backend/internal/service/ports/concurrency_cache.go
Normal file
19
backend/internal/service/ports/concurrency_cache.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package ports
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// ConcurrencyCache defines cache operations for concurrency service
|
||||||
|
type ConcurrencyCache interface {
|
||||||
|
// Slot management
|
||||||
|
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error)
|
||||||
|
ReleaseAccountSlot(ctx context.Context, accountID int64) error
|
||||||
|
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||||
|
|
||||||
|
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error)
|
||||||
|
ReleaseUserSlot(ctx context.Context, userID int64) error
|
||||||
|
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||||
|
|
||||||
|
// Wait queue
|
||||||
|
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||||
|
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||||
|
}
|
||||||
20
backend/internal/service/ports/email_cache.go
Normal file
20
backend/internal/service/ports/email_cache.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package ports
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VerificationCodeData represents verification code data
|
||||||
|
type VerificationCodeData struct {
|
||||||
|
Code string
|
||||||
|
Attempts int
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmailCache defines cache operations for email service
|
||||||
|
type EmailCache interface {
|
||||||
|
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||||
|
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||||
|
DeleteVerificationCode(ctx context.Context, email string) error
|
||||||
|
}
|
||||||
13
backend/internal/service/ports/gateway_cache.go
Normal file
13
backend/internal/service/ports/gateway_cache.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package ports
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GatewayCache defines cache operations for gateway service
|
||||||
|
type GatewayCache interface {
|
||||||
|
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
|
||||||
|
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
|
||||||
|
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||||
|
}
|
||||||
21
backend/internal/service/ports/identity_cache.go
Normal file
21
backend/internal/service/ports/identity_cache.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package ports
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Fingerprint represents account fingerprint data
|
||||||
|
type Fingerprint struct {
|
||||||
|
ClientID string
|
||||||
|
UserAgent string
|
||||||
|
StainlessLang string
|
||||||
|
StainlessPackageVersion string
|
||||||
|
StainlessOS string
|
||||||
|
StainlessArch string
|
||||||
|
StainlessRuntime string
|
||||||
|
StainlessRuntimeVersion string
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdentityCache defines cache operations for identity service
|
||||||
|
type IdentityCache interface {
|
||||||
|
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
|
||||||
|
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
|
||||||
|
}
|
||||||
15
backend/internal/service/ports/redeem_cache.go
Normal file
15
backend/internal/service/ports/redeem_cache.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package ports
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RedeemCache defines cache operations for redeem service
|
||||||
|
type RedeemCache interface {
|
||||||
|
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||||
|
IncrementRedeemAttemptCount(ctx context.Context, userID int64) error
|
||||||
|
|
||||||
|
AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error)
|
||||||
|
ReleaseRedeemLock(ctx context.Context, code string) error
|
||||||
|
}
|
||||||
12
backend/internal/service/ports/update_cache.go
Normal file
12
backend/internal/service/ports/update_cache.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package ports
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateCache defines cache operations for update service
|
||||||
|
type UpdateCache interface {
|
||||||
|
GetUpdateInfo(ctx context.Context) (string, error)
|
||||||
|
SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error
|
||||||
|
}
|
||||||
@@ -26,11 +26,9 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
|
redeemMaxErrorsPerHour = 20
|
||||||
redeemLockKeyPrefix = "redeem:lock:"
|
redeemRateLimitDuration = time.Hour
|
||||||
redeemMaxErrorsPerHour = 20
|
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||||
redeemRateLimitDuration = time.Hour
|
|
||||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// GenerateCodesRequest 生成兑换码请求
|
// GenerateCodesRequest 生成兑换码请求
|
||||||
@@ -53,7 +51,7 @@ type RedeemService struct {
|
|||||||
redeemRepo ports.RedeemCodeRepository
|
redeemRepo ports.RedeemCodeRepository
|
||||||
userRepo ports.UserRepository
|
userRepo ports.UserRepository
|
||||||
subscriptionService *SubscriptionService
|
subscriptionService *SubscriptionService
|
||||||
rdb *redis.Client
|
cache ports.RedeemCache
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,14 +60,14 @@ func NewRedeemService(
|
|||||||
redeemRepo ports.RedeemCodeRepository,
|
redeemRepo ports.RedeemCodeRepository,
|
||||||
userRepo ports.UserRepository,
|
userRepo ports.UserRepository,
|
||||||
subscriptionService *SubscriptionService,
|
subscriptionService *SubscriptionService,
|
||||||
rdb *redis.Client,
|
cache ports.RedeemCache,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
) *RedeemService {
|
) *RedeemService {
|
||||||
return &RedeemService{
|
return &RedeemService{
|
||||||
redeemRepo: redeemRepo,
|
redeemRepo: redeemRepo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
subscriptionService: subscriptionService,
|
subscriptionService: subscriptionService,
|
||||||
rdb: rdb,
|
cache: cache,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -140,13 +138,11 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
|||||||
|
|
||||||
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
|
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
|
||||||
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
|
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
|
||||||
|
|
||||||
count, err := s.rdb.Get(ctx, key).Int()
|
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
// Redis 出错时不阻止用户操作
|
// Redis 出错时不阻止用户操作
|
||||||
return nil
|
return nil
|
||||||
@@ -161,27 +157,21 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64)
|
|||||||
|
|
||||||
// incrementRedeemErrorCount 增加用户兑换错误计数
|
// incrementRedeemErrorCount 增加用户兑换错误计数
|
||||||
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
|
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
_ = s.cache.IncrementRedeemAttemptCount(ctx, userID)
|
||||||
|
|
||||||
pipe := s.rdb.Pipeline()
|
|
||||||
pipe.Incr(ctx, key)
|
|
||||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
|
||||||
_, _ = pipe.Exec(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// acquireRedeemLock 尝试获取兑换码的分布式锁
|
// acquireRedeemLock 尝试获取兑换码的分布式锁
|
||||||
// 返回 true 表示获取成功,false 表示锁已被占用
|
// 返回 true 表示获取成功,false 表示锁已被占用
|
||||||
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
|
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return true // 无 Redis 时降级为不加锁
|
return true // 无 Redis 时降级为不加锁
|
||||||
}
|
}
|
||||||
|
|
||||||
key := redeemLockKeyPrefix + code
|
ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration)
|
||||||
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
|
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
|
||||||
return true
|
return true
|
||||||
@@ -191,12 +181,11 @@ func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool
|
|||||||
|
|
||||||
// releaseRedeemLock 释放兑换码的分布式锁
|
// releaseRedeemLock 释放兑换码的分布式锁
|
||||||
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
|
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key := redeemLockKeyPrefix + code
|
_ = s.cache.ReleaseRedeemLock(ctx, code)
|
||||||
s.rdb.Del(ctx, key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redeem 使用兑换码
|
// Redeem 使用兑换码
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"sub2api/internal/service/ports"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -36,15 +36,15 @@ const (
|
|||||||
|
|
||||||
// UpdateService handles software updates
|
// UpdateService handles software updates
|
||||||
type UpdateService struct {
|
type UpdateService struct {
|
||||||
rdb *redis.Client
|
cache ports.UpdateCache
|
||||||
currentVersion string
|
currentVersion string
|
||||||
buildType string // "source" for manual builds, "release" for CI builds
|
buildType string // "source" for manual builds, "release" for CI builds
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUpdateService creates a new UpdateService
|
// NewUpdateService creates a new UpdateService
|
||||||
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService {
|
func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService {
|
||||||
return &UpdateService{
|
return &UpdateService{
|
||||||
rdb: rdb,
|
cache: cache,
|
||||||
currentVersion: version,
|
currentVersion: version,
|
||||||
buildType: buildType,
|
buildType: buildType,
|
||||||
}
|
}
|
||||||
@@ -533,7 +533,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||||
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
|
data, err := s.cache.GetUpdateInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -573,7 +573,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
data, _ := json.Marshal(cacheData)
|
data, _ := json.Marshal(cacheData)
|
||||||
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second)
|
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
// compareVersions compares two semantic versions
|
// compareVersions compares two semantic versions
|
||||||
|
|||||||
Reference in New Issue
Block a user