Merge PR #8: refactor(backend): 添加 service 缓存端口

This commit is contained in:
shaw
2025-12-20 11:05:01 +08:00
27 changed files with 906 additions and 468 deletions

View File

@@ -41,7 +41,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
settingRepository := repository.NewSettingRepository(db) settingRepository := repository.NewSettingRepository(db)
settingService := service.NewSettingService(settingRepository, configConfig) settingService := service.NewSettingService(settingRepository, configConfig)
client := infrastructure.ProvideRedis(configConfig) client := infrastructure.ProvideRedis(configConfig)
emailService := service.NewEmailService(settingRepository, client) emailCache := repository.NewEmailCache(client)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileService := service.NewTurnstileService(settingService) turnstileService := service.NewTurnstileService(settingService)
emailQueueService := service.ProvideEmailQueueService(emailService) emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
@@ -51,15 +52,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyRepository := repository.NewApiKeyRepository(db) apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db) groupRepository := repository.NewGroupRepository(db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db) userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig) apiKeyCache := repository.NewApiKeyCache(client)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(db) usageLogRepository := repository.NewUsageLogRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository) usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db) redeemCodeRepository := repository.NewRedeemCodeRepository(db)
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository) billingCache := repository.NewBillingCache(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService) redeemCache := repository.NewRedeemCache(client)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService) redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
accountRepository := repository.NewAccountRepository(db) accountRepository := repository.NewAccountRepository(db)
@@ -81,14 +85,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService) adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client)
pricingService, err := service.ProvidePricingService(configConfig) pricingService, err := service.ProvidePricingService(configConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
billingService := service.NewBillingService(configConfig, pricingService) billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(client) identityCache := repository.NewIdentityCache(client)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService) identityService := service.NewIdentityService(identityCache)
concurrencyService := service.NewConcurrencyService(client) gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)

View File

@@ -6,6 +6,7 @@ import (
"sub2api/internal/pkg/response" "sub2api/internal/pkg/response"
"sub2api/internal/pkg/sysutil" "sub2api/internal/pkg/sysutil"
"sub2api/internal/repository"
"sub2api/internal/service" "sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -19,8 +20,9 @@ type SystemHandler struct {
// NewSystemHandler creates a new SystemHandler // NewSystemHandler creates a new SystemHandler
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler { func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
updateCache := repository.NewUpdateCache(rdb)
return &SystemHandler{ return &SystemHandler{
updateSvc: service.NewUpdateService(rdb, version, buildType), updateSvc: service.NewUpdateService(updateCache, version, buildType),
} }
} }

View File

@@ -0,0 +1,51 @@
package repository
import (
"context"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
)
type apiKeyCache struct {
rdb *redis.Client
}
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
return &apiKeyCache{rdb: rdb}
}
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return c.rdb.Incr(ctx, apiKey).Err()
}
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}

View File

@@ -0,0 +1,174 @@
package repository
import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
)
var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
type billingCache struct {
rdb *redis.Client
}
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
return &billingCache{rdb: rdb}
}
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return 0, err
}
return strconv.ParseFloat(val, 64)
}
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
return c.parseSubscriptionCache(result)
}
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
result := &ports.SubscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
if data == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]interface{}{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
}
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,132 @@
package repository
import (
"context"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
accountConcurrencyKeyPrefix = "concurrency:account:"
userConcurrencyKeyPrefix = "concurrency:user:"
waitQueueKeyPrefix = "concurrency:wait:"
concurrencyTTL = 5 * time.Minute
)
var (
acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
`)
releaseScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
return 1
`)
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
type concurrencyCache struct {
rdb *redis.Client
}
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
}
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
return c.rdb.Get(ctx, key).Int()
}
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -0,0 +1,48 @@
package repository
import (
"context"
"encoding/json"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
type emailCache struct {
rdb *redis.Client
}
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
return &emailCache{rdb: rdb}
}
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
key := verifyCodeKeyPrefix + email
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data ports.VerificationCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKeyPrefix + email
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
key := verifyCodeKeyPrefix + email
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,35 @@
package repository
import (
"context"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const stickySessionPrefix = "sticky_session:"
type gatewayCache struct {
rdb *redis.Client
}
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
return &gatewayCache{rdb: rdb}
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
key := stickySessionPrefix + sessionHash
return c.rdb.Get(ctx, key).Int64()
}
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
return c.rdb.Set(ctx, key, accountID, ttl).Err()
}
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
return c.rdb.Expire(ctx, key, ttl).Err()
}

View File

@@ -0,0 +1,47 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
fingerprintKeyPrefix = "fingerprint:"
fingerprintTTL = 24 * time.Hour
)
type identityCache struct {
rdb *redis.Client
}
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
return &identityCache{rdb: rdb}
}
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var fp ports.Fingerprint
if err := json.Unmarshal([]byte(val), &fp); err != nil {
return nil, err
}
return &fp, nil
}
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := json.Marshal(fp)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
}

View File

@@ -0,0 +1,49 @@
package repository
import (
"context"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemRateLimitDuration = 24 * time.Hour
)
type redeemCache struct {
rdb *redis.Client
}
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
return &redeemCache{rdb: rdb}
}
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
key := redeemLockKeyPrefix + code
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
key := redeemLockKeyPrefix + code
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,28 @@
package repository
import (
"context"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const updateCacheKey = "update:latest"
type updateCache struct {
rdb *redis.Client
}
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
return &updateCache{rdb: rdb}
}
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
return c.rdb.Get(ctx, updateCacheKey).Result()
}
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
}

View File

@@ -19,6 +19,16 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
wire.Struct(new(Repositories), "*"), wire.Struct(new(Repositories), "*"),
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewApiKeyCache,
NewConcurrencyCache,
NewEmailCache,
NewIdentityCache,
NewRedeemCache,
NewUpdateCache,
// Bind concrete repositories to service port interfaces // Bind concrete repositories to service port interfaces
wire.Bind(new(ports.UserRepository), new(*UserRepository)), wire.Bind(new(ports.UserRepository), new(*UserRepository)),
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)), wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),

View File

@@ -27,9 +27,7 @@ var (
) )
const ( const (
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:" apiKeyMaxErrorsPerHour = 20
apiKeyMaxErrorsPerHour = 20
apiKeyRateLimitDuration = time.Hour
) )
// CreateApiKeyRequest 创建API Key请求 // CreateApiKeyRequest 创建API Key请求
@@ -52,7 +50,7 @@ type ApiKeyService struct {
userRepo ports.UserRepository userRepo ports.UserRepository
groupRepo ports.GroupRepository groupRepo ports.GroupRepository
userSubRepo ports.UserSubscriptionRepository userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client cache ports.ApiKeyCache
cfg *config.Config cfg *config.Config
} }
@@ -62,7 +60,7 @@ func NewApiKeyService(
userRepo ports.UserRepository, userRepo ports.UserRepository,
groupRepo ports.GroupRepository, groupRepo ports.GroupRepository,
userSubRepo ports.UserSubscriptionRepository, userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client, cache ports.ApiKeyCache,
cfg *config.Config, cfg *config.Config,
) *ApiKeyService { ) *ApiKeyService {
return &ApiKeyService{ return &ApiKeyService{
@@ -70,7 +68,7 @@ func NewApiKeyService(
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
rdb: rdb, cache: cache,
cfg: cfg, cfg: cfg,
} }
} }
@@ -113,13 +111,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限 // checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error { func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil { if s.cache == nil {
return nil return nil
} }
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) count, err := s.cache.GetCreateAttemptCount(ctx, userID)
count, err := s.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) { if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作 // Redis 出错时不阻止用户操作
return nil return nil
@@ -134,16 +130,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数 // incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) { func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil { if s.cache == nil {
return return
} }
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) _ = s.cache.IncrementCreateAttemptCount(ctx, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, _ = pipe.Exec(ctx)
} }
// canUserBindGroup 检查用户是否可以绑定指定分组 // canUserBindGroup 检查用户是否可以绑定指定分组
@@ -273,7 +264,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
} }
// 缓存到Redis可选TTL设置为5分钟 // 缓存到Redis可选TTL设置为5分钟
if s.rdb != nil { if s.cache != nil {
// 这里可以序列化并缓存API Key // 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误 _ = cacheKey // 使用变量避免未使用错误
} }
@@ -326,9 +317,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
if req.Status != nil { if req.Status != nil {
apiKey.Status = *req.Status apiKey.Status = *req.Status
// 如果状态改变清除Redis缓存 // 如果状态改变清除Redis缓存
if s.rdb != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key) _ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
_ = s.rdb.Del(ctx, cacheKey)
} }
} }
@@ -355,9 +345,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
} }
// 清除Redis缓存 // 清除Redis缓存
if s.rdb != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key) _ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
_ = s.rdb.Del(ctx, cacheKey)
} }
if err := s.apiKeyRepo.Delete(ctx, id); err != nil { if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
@@ -400,13 +389,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
// IncrementUsage 增加API Key使用次数可选用于统计 // IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器 // 使用Redis计数器
if s.rdb != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02")) cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil { if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
return fmt.Errorf("increment usage: %w", err) return fmt.Errorf("increment usage: %w", err)
} }
// 设置24小时过期 // 设置24小时过期
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour) _ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
} }
return nil return nil
} }

View File

@@ -5,30 +5,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"strconv"
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/service/ports" "sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
// 缓存Key前缀和TTL
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
// 订阅缓存Hash字段
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
) )
// 错误定义 // 错误定义
@@ -38,35 +18,6 @@ var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired") ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
) )
// 预编译的Lua脚本
var (
// deductBalanceScript: 扣减余额缓存key不存在则忽略
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// updateSubUsageScript: 更新订阅用量缓存key不存在则忽略
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)
type subscriptionCacheData struct { type subscriptionCacheData struct {
Status string Status string
@@ -80,15 +31,15 @@ type subscriptionCacheData struct {
// BillingCacheService 计费缓存服务 // BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct { type BillingCacheService struct {
rdb *redis.Client cache ports.BillingCache
userRepo ports.UserRepository userRepo ports.UserRepository
subRepo ports.UserSubscriptionRepository subRepo ports.UserSubscriptionRepository
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService { func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
return &BillingCacheService{ return &BillingCacheService{
rdb: rdb, cache: cache,
userRepo: userRepo, userRepo: userRepo,
subRepo: subRepo, subRepo: subRepo,
} }
@@ -100,24 +51,19 @@ func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, su
// GetUserBalance 获取用户余额(优先从缓存读取) // GetUserBalance 获取用户余额(优先从缓存读取)
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) { func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
if s.rdb == nil { if s.cache == nil {
// Redis不可用直接查询数据库 // Redis不可用直接查询数据库
return s.getUserBalanceFromDB(ctx, userID) return s.getUserBalanceFromDB(ctx, userID)
} }
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 尝试从缓存读取 // 尝试从缓存读取
val, err := s.rdb.Get(ctx, key).Result() balance, err := s.cache.GetUserBalance(ctx, userID)
if err == nil { if err == nil {
balance, parseErr := strconv.ParseFloat(val, 64) return balance, nil
if parseErr == nil {
return balance, nil
}
} }
// 缓存未命中或解析错误,从数据库读取 // 缓存未命中,从数据库读取
balance, err := s.getUserBalanceFromDB(ctx, userID) balance, err = s.getUserBalanceFromDB(ctx, userID)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -143,39 +89,28 @@ func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID i
// setBalanceCache 设置余额缓存 // setBalanceCache 设置余额缓存
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) { func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
if s.rdb == nil { if s.cache == nil {
return return
} }
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err) log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
} }
} }
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存) // DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error { func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.rdb == nil { if s.cache == nil {
return nil return nil
} }
return s.cache.DeductUserBalance(ctx, userID, amount)
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 使用预编译的Lua脚本原子性扣减如果key不存在则忽略
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
} }
// InvalidateUserBalance 失效用户余额缓存 // InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error { func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.rdb == nil { if s.cache == nil {
return nil return nil
} }
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err) log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
return err return err
} }
@@ -188,19 +123,14 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取) // GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) { func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
if s.rdb == nil { if s.cache == nil {
return s.getSubscriptionFromDB(ctx, userID, groupID) return s.getSubscriptionFromDB(ctx, userID, groupID)
} }
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 尝试从缓存读取 // 尝试从缓存读取
result, err := s.rdb.HGetAll(ctx, key).Result() cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
if err == nil && len(result) > 0 { if err == nil && cacheData != nil {
data, parseErr := s.parseSubscriptionCache(result) return s.convertFromPortsData(cacheData), nil
if parseErr == nil {
return data, nil
}
} }
// 缓存未命中,从数据库读取 // 缓存未命中,从数据库读取
@@ -219,6 +149,28 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
return data, nil return data, nil
} }
func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData {
return &subscriptionCacheData{
Status: data.Status,
ExpiresAt: data.ExpiresAt,
DailyUsage: data.DailyUsage,
WeeklyUsage: data.WeeklyUsage,
MonthlyUsage: data.MonthlyUsage,
Version: data.Version,
}
}
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData {
return &ports.SubscriptionCacheData{
Status: data.Status,
ExpiresAt: data.ExpiresAt,
DailyUsage: data.DailyUsage,
WeeklyUsage: data.WeeklyUsage,
MonthlyUsage: data.MonthlyUsage,
Version: data.Version,
}
}
// getSubscriptionFromDB 从数据库获取订阅数据 // getSubscriptionFromDB 从数据库获取订阅数据
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) { func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
@@ -236,90 +188,30 @@ func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID,
}, nil }, nil
} }
// parseSubscriptionCache 解析订阅缓存数据
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
result := &subscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
// setSubscriptionCache 设置订阅缓存 // setSubscriptionCache 设置订阅缓存
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) { func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
if s.rdb == nil || data == nil { if s.cache == nil || data == nil {
return return
} }
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]interface{}{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := s.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
if _, err := pipe.Exec(ctx); err != nil {
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err) log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
} }
} }
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存) // UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error { func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.rdb == nil { if s.cache == nil {
return nil return nil
} }
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 使用预编译的Lua脚本原子性增加用量如果key不存在则忽略
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
} }
// InvalidateSubscription 失效指定订阅缓存 // InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error { func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.rdb == nil { if s.cache == nil {
return nil return nil
} }
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err) log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
return err return err
} }

View File

@@ -2,22 +2,13 @@ package service
import ( import (
"context" "context"
"fmt"
"log" "log"
"time" "time"
"github.com/redis/go-redis/v9" "sub2api/internal/service/ports"
) )
const ( const (
// Redis key prefixes
accountConcurrencyKey = "concurrency:account:"
userConcurrencyKey = "concurrency:user:"
userWaitCountKey = "concurrency:wait:"
// TTL for concurrency keys (auto-release safety net)
concurrencyKeyTTL = 10 * time.Minute
// Wait polling interval // Wait polling interval
waitPollInterval = 100 * time.Millisecond waitPollInterval = 100 * time.Millisecond
@@ -28,70 +19,14 @@ const (
defaultExtraWaitSlots = 20 defaultExtraWaitSlots = 20
) )
// Pre-compiled Lua scripts for better performance
var (
// acquireScript: increment counter if below max, return 1 if successful
acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
`)
// releaseScript: decrement counter, but don't go below 0
releaseScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// incrementWaitScript: increment wait counter if below max, return 1 if successful
incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
return 1
`)
// decrementWaitScript: decrement wait counter, but don't go below 0
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
// ConcurrencyService manages concurrent request limiting for accounts and users // ConcurrencyService manages concurrent request limiting for accounts and users
type ConcurrencyService struct { type ConcurrencyService struct {
rdb *redis.Client cache ports.ConcurrencyCache
} }
// NewConcurrencyService creates a new ConcurrencyService // NewConcurrencyService creates a new ConcurrencyService
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService { func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
return &ConcurrencyService{rdb: rdb} return &ConcurrencyService{cache: cache}
} }
// AcquireResult represents the result of acquiring a concurrency slot // AcquireResult represents the result of acquiring a concurrency slot
@@ -104,20 +39,6 @@ type AcquireResult struct {
// If the account is at max concurrency, it waits until a slot is available or timeout. // If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes. // Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
// If the user is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// acquireSlot is the core implementation for acquiring a concurrency slot
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit // If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 { if maxConcurrency <= 0 {
return &AcquireResult{ return &AcquireResult{
@@ -126,8 +47,7 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
}, nil }, nil
} }
// Try to acquire immediately acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -135,64 +55,56 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
if acquired { if acquired {
return &AcquireResult{ return &AcquireResult{
Acquired: true, Acquired: true,
ReleaseFunc: s.makeReleaseFunc(key), ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
}
},
}, nil }, nil
} }
// Not acquired, return with Acquired=false
// The caller (gateway handler) will handle waiting with ping support
return &AcquireResult{ return &AcquireResult{
Acquired: false, Acquired: false,
ReleaseFunc: nil, ReleaseFunc: nil,
}, nil }, nil
} }
// tryAcquire attempts to increment the counter if below max // AcquireUserSlot attempts to acquire a concurrency slot for a user.
// Uses pre-compiled Lua script for atomicity and performance // If the user is at max concurrency, it waits until a slot is available or timeout.
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) { // Returns a release function that MUST be called when the request completes.
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int() func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {}, // no-op
}, nil
}
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil { if err != nil {
return false, fmt.Errorf("acquire slot failed: %w", err) return nil, err
} }
return result == 1, nil
}
// makeReleaseFunc creates a function to release a concurrency slot if acquired {
func (s *ConcurrencyService) makeReleaseFunc(key string) func() { return &AcquireResult{
return func() { Acquired: true,
// Use background context to ensure release even if original context is cancelled ReleaseFunc: func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil { log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
// Log error but don't panic - TTL will eventually clean up }
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err) },
} }, nil
} }
}
// GetCurrentCount returns the current concurrency count for debugging/monitoring return &AcquireResult{
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) { Acquired: false,
val, err := s.rdb.Get(ctx, key).Int() ReleaseFunc: nil,
if err == redis.Nil { }, nil
return 0, nil
}
if err != nil {
return 0, err
}
return val, nil
}
// GetAccountCurrentCount returns current concurrency count for an account
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.GetCurrentCount(ctx, key)
}
// GetUserCurrentCount returns current concurrency count for a user
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.GetCurrentCount(ctx, key)
} }
// ============================================ // ============================================
@@ -203,44 +115,36 @@ func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int
// Returns true if successful, false if the wait queue is full. // Returns true if successful, false if the wait queue is full.
// maxWait should be user.Concurrency + defaultExtraWaitSlots // maxWait should be user.Concurrency + defaultExtraWaitSlots
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
if s.rdb == nil { if s.cache == nil {
// Redis not available, allow request // Redis not available, allow request
return true, nil return true, nil
} }
key := fmt.Sprintf("%s%d", userWaitCountKey, userID) result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
if err != nil { if err != nil {
// On error, allow the request to proceed (fail open) // On error, allow the request to proceed (fail open)
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err) log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
return true, nil return true, nil
} }
return result == 1, nil return result, nil
} }
// DecrementWaitCount decrements the wait queue counter for a user. // DecrementWaitCount decrements the wait queue counter for a user.
// Should be called when a request completes or exits the wait queue. // Should be called when a request completes or exits the wait queue.
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) { func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
if s.rdb == nil { if s.cache == nil {
return return
} }
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
// Use background context to ensure decrement even if original context is cancelled // Use background context to ensure decrement even if original context is cancelled
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil { if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err) log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
} }
} }
// GetUserWaitCount returns current wait queue count for a user
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
return s.GetCurrentCount(ctx, key)
}
// CalculateMaxWait calculates the maximum wait queue size for a user // CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots // maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int { func CalculateMaxWait(userConcurrency int) int {

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
@@ -13,8 +12,6 @@ import (
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/service/ports" "sub2api/internal/service/ports"
"time" "time"
"github.com/redis/go-redis/v9"
) )
var ( var (
@@ -25,19 +22,11 @@ var (
) )
const ( const (
verifyCodeKeyPrefix = "email_verify:"
verifyCodeTTL = 15 * time.Minute verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5 maxVerifyCodeAttempts = 5
) )
// verifyCodeData Redis 中存储的验证码数据
type verifyCodeData struct {
Code string `json:"code"`
Attempts int `json:"attempts"`
CreatedAt time.Time `json:"created_at"`
}
// SmtpConfig SMTP配置 // SmtpConfig SMTP配置
type SmtpConfig struct { type SmtpConfig struct {
Host string Host string
@@ -52,14 +41,14 @@ type SmtpConfig struct {
// EmailService 邮件服务 // EmailService 邮件服务
type EmailService struct { type EmailService struct {
settingRepo ports.SettingRepository settingRepo ports.SettingRepository
rdb *redis.Client cache ports.EmailCache
} }
// NewEmailService 创建邮件服务实例 // NewEmailService 创建邮件服务实例
func NewEmailService(settingRepo ports.SettingRepository, rdb *redis.Client) *EmailService { func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService {
return &EmailService{ return &EmailService{
settingRepo: settingRepo, settingRepo: settingRepo,
rdb: rdb, cache: cache,
} }
} }
@@ -201,10 +190,8 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
// SendVerifyCode 发送验证码邮件 // SendVerifyCode 发送验证码邮件
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error { func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
key := verifyCodeKeyPrefix + email
// 检查是否在冷却期内 // 检查是否在冷却期内
existing, err := s.getVerifyCodeData(ctx, key) existing, err := s.cache.GetVerificationCode(ctx, email)
if err == nil && existing != nil { if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown { if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent return ErrVerifyCodeTooFrequent
@@ -218,12 +205,12 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
} }
// 保存验证码到 Redis // 保存验证码到 Redis
data := &verifyCodeData{ data := &ports.VerificationCodeData{
Code: code, Code: code,
Attempts: 0, Attempts: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if err := s.setVerifyCodeData(ctx, key, data); err != nil { if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err) return fmt.Errorf("save verify code: %w", err)
} }
@@ -241,9 +228,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
// VerifyCode 验证验证码 // VerifyCode 验证验证码
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error { func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
key := verifyCodeKeyPrefix + email data, err := s.cache.GetVerificationCode(ctx, email)
data, err := s.getVerifyCodeData(ctx, key)
if err != nil || data == nil { if err != nil || data == nil {
return ErrInvalidVerifyCode return ErrInvalidVerifyCode
} }
@@ -256,7 +241,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 // 验证码不匹配
if data.Code != code { if data.Code != code {
data.Attempts++ data.Attempts++
_ = s.setVerifyCodeData(ctx, key, data) _ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
if data.Attempts >= maxVerifyCodeAttempts { if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts return ErrVerifyCodeMaxAttempts
} }
@@ -264,32 +249,10 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
} }
// 验证成功,删除验证码 // 验证成功,删除验证码
s.rdb.Del(ctx, key) _ = s.cache.DeleteVerificationCode(ctx, email)
return nil return nil
} }
// getVerifyCodeData 从 Redis 获取验证码数据
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
val, err := s.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data verifyCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
// setVerifyCodeData 保存验证码数据到 Redis
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
val, err := json.Marshal(data)
if err != nil {
return err
}
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
}
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容 // buildVerifyCodeEmailBody 构建验证码邮件HTML内容
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string { func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
return fmt.Sprintf(` return fmt.Sprintf(`

View File

@@ -24,13 +24,11 @@ import (
"sub2api/internal/service/ports" "sub2api/internal/service/ports"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
) )
const ( const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionPrefix = "sticky_session:"
stickySessionTTL = time.Hour // 粘性会话TTL stickySessionTTL = time.Hour // 粘性会话TTL
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
) )
@@ -82,7 +80,7 @@ type GatewayService struct {
usageLogRepo ports.UsageLogRepository usageLogRepo ports.UsageLogRepository
userRepo ports.UserRepository userRepo ports.UserRepository
userSubRepo ports.UserSubscriptionRepository userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client cache ports.GatewayCache
cfg *config.Config cfg *config.Config
oauthService *OAuthService oauthService *OAuthService
billingService *BillingService billingService *BillingService
@@ -98,7 +96,7 @@ func NewGatewayService(
usageLogRepo ports.UsageLogRepository, usageLogRepo ports.UsageLogRepository,
userRepo ports.UserRepository, userRepo ports.UserRepository,
userSubRepo ports.UserSubscriptionRepository, userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client, cache ports.GatewayCache,
cfg *config.Config, cfg *config.Config,
oauthService *OAuthService, oauthService *OAuthService,
billingService *BillingService, billingService *BillingService,
@@ -124,7 +122,7 @@ func NewGatewayService(
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
userRepo: userRepo, userRepo: userRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
rdb: rdb, cache: cache,
cfg: cfg, cfg: cfg,
oauthService: oauthService, oauthService: oauthService,
billingService: billingService, billingService: billingService,
@@ -290,14 +288,14 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64() accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中 // 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持 // 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// 续期粘性会话 // 续期粘性会话
s.rdb.Expire(ctx, stickySessionPrefix+sessionHash, stickySessionTTL) s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
return account, nil return account, nil
} }
} }
@@ -347,7 +345,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// 4. 建立粘性绑定 // 4. 建立粘性绑定
if sessionHash != "" { if sessionHash != "" {
s.rdb.Set(ctx, stickySessionPrefix+sessionHash, selected.ID, stickySessionTTL) s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
} }
return selected, nil return selected, nil
@@ -526,7 +524,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
// OAuth账号应用统一指纹 // OAuth账号应用统一指纹
var fingerprint *Fingerprint var fingerprint *ports.Fingerprint
if account.IsOAuth() && s.identityService != nil { if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID // 1. 获取或创建指纹包含随机生成的ClientID
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)

View File

@@ -11,15 +11,10 @@ import (
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
"sub2api/internal/service/ports"
"time" "time"
"github.com/redis/go-redis/v9"
) )
const (
// Redis key prefix
identityFingerprintKey = "identity:fingerprint:"
)
// 预编译正则表达式(避免每次调用重新编译) // 预编译正则表达式(避免每次调用重新编译)
var ( var (
@@ -29,20 +24,8 @@ var (
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
) )
// Fingerprint 存储的指纹数据结构
type Fingerprint struct {
ClientID string `json:"client_id"` // 64位hex客户端ID首次随机生成
UserAgent string `json:"user_agent"` // User-Agent
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
}
// 默认指纹值(当客户端未提供时使用) // 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{ var defaultFingerprint = ports.Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)", UserAgent: "claude-cli/2.0.62 (external, cli)",
StainlessLang: "js", StainlessLang: "js",
StainlessPackageVersion: "0.52.0", StainlessPackageVersion: "0.52.0",
@@ -54,39 +37,31 @@ var defaultFingerprint = Fingerprint{
// IdentityService 管理OAuth账号的请求身份指纹 // IdentityService 管理OAuth账号的请求身份指纹
type IdentityService struct { type IdentityService struct {
rdb *redis.Client cache ports.IdentityCache
} }
// NewIdentityService 创建新的IdentityService // NewIdentityService 创建新的IdentityService
func NewIdentityService(rdb *redis.Client) *IdentityService { func NewIdentityService(cache ports.IdentityCache) *IdentityService {
return &IdentityService{rdb: rdb} return &IdentityService{cache: cache}
} }
// GetOrCreateFingerprint 获取或创建账号的指纹 // GetOrCreateFingerprint 获取或创建账号的指纹
// 如果缓存存在检测user-agent版本新版本则更新 // 如果缓存存在检测user-agent版本新版本则更新
// 如果缓存不存在生成随机ClientID并从请求头创建指纹然后缓存 // 如果缓存不存在生成随机ClientID并从请求头创建指纹然后缓存
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) { func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) {
key := identityFingerprintKey + strconv.FormatInt(accountID, 10) // 尝试从缓存获取指纹
cached, err := s.cache.GetFingerprint(ctx, accountID)
// 尝试从Redis获取缓存的指纹 if err == nil && cached != nil {
data, err := s.rdb.Get(ctx, key).Bytes() // 检查客户端的user-agent是否是更新版本
if err == nil && len(data) > 0 { clientUA := headers.Get("User-Agent")
// 缓存存在,解析指纹 if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
var cached Fingerprint // 更新user-agent
if err := json.Unmarshal(data, &cached); err == nil { cached.UserAgent = clientUA
// 检查客户端的user-agent是否是更新版本 // 保存更新后的指纹
clientUA := headers.Get("User-Agent") _ = s.cache.SetFingerprint(ctx, accountID, cached)
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) { log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
// 更新user-agent
cached.UserAgent = clientUA
// 保存更新后的指纹
if newData, err := json.Marshal(cached); err == nil {
s.rdb.Set(ctx, key, newData, 0) // 永不过期
}
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
}
return &cached, nil
} }
return cached, nil
} }
// 缓存不存在或解析失败,创建新指纹 // 缓存不存在或解析失败,创建新指纹
@@ -95,11 +70,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 生成随机ClientID // 生成随机ClientID
fp.ClientID = generateClientID() fp.ClientID = generateClientID()
// 保存到Redis(永不过期) // 保存到缓存(永不过期)
if data, err := json.Marshal(fp); err == nil { if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil { log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
} }
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID) log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
@@ -107,8 +80,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
} }
// createFingerprintFromHeaders 从请求头创建指纹 // createFingerprintFromHeaders 从请求头创建指纹
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint { func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint {
fp := &Fingerprint{} fp := &ports.Fingerprint{}
// 获取User-Agent // 获取User-Agent
if ua := headers.Get("User-Agent"); ua != "" { if ua := headers.Get("User-Agent"); ua != "" {
@@ -137,7 +110,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
} }
// ApplyFingerprint 将指纹应用到请求头覆盖原有的x-stainless-*头) // ApplyFingerprint 将指纹应用到请求头覆盖原有的x-stainless-*头)
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) {
if fp == nil { if fp == nil {
return return
} }

View File

@@ -0,0 +1,16 @@
package ports
import (
"context"
"time"
)
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}

View File

@@ -0,0 +1,31 @@
package ports
import (
"context"
"time"
)
// SubscriptionCacheData represents cached subscription data
type SubscriptionCacheData struct {
Status string
ExpiresAt time.Time
DailyUsage float64
WeeklyUsage float64
MonthlyUsage float64
Version int64
}
// BillingCache defines cache operations for billing service
type BillingCache interface {
// Balance operations
GetUserBalance(ctx context.Context, userID int64) (float64, error)
SetUserBalance(ctx context.Context, userID int64, balance float64) error
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
InvalidateUserBalance(ctx context.Context, userID int64) error
// Subscription operations
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
}

View File

@@ -0,0 +1,19 @@
package ports
import "context"
// ConcurrencyCache defines cache operations for concurrency service
type ConcurrencyCache interface {
// Slot management
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
}

View File

@@ -0,0 +1,20 @@
package ports
import (
"context"
"time"
)
// VerificationCodeData represents verification code data
type VerificationCodeData struct {
Code string
Attempts int
CreatedAt time.Time
}
// EmailCache defines cache operations for email service
type EmailCache interface {
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
}

View File

@@ -0,0 +1,13 @@
package ports
import (
"context"
"time"
)
// GatewayCache defines cache operations for gateway service
type GatewayCache interface {
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}

View File

@@ -0,0 +1,21 @@
package ports
import "context"
// Fingerprint represents account fingerprint data
type Fingerprint struct {
ClientID string
UserAgent string
StainlessLang string
StainlessPackageVersion string
StainlessOS string
StainlessArch string
StainlessRuntime string
StainlessRuntimeVersion string
}
// IdentityCache defines cache operations for identity service
type IdentityCache interface {
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
}

View File

@@ -0,0 +1,15 @@
package ports
import (
"context"
"time"
)
// RedeemCache defines cache operations for redeem service
type RedeemCache interface {
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementRedeemAttemptCount(ctx context.Context, userID int64) error
AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error)
ReleaseRedeemLock(ctx context.Context, code string) error
}

View File

@@ -0,0 +1,12 @@
package ports
import (
"context"
"time"
)
// UpdateCache defines cache operations for update service
type UpdateCache interface {
GetUpdateInfo(ctx context.Context) (string, error)
SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error
}

View File

@@ -26,11 +26,9 @@ var (
) )
const ( const (
redeemRateLimitKeyPrefix = "redeem:rate_limit:" redeemMaxErrorsPerHour = 20
redeemLockKeyPrefix = "redeem:lock:" redeemRateLimitDuration = time.Hour
redeemMaxErrorsPerHour = 20 redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
redeemRateLimitDuration = time.Hour
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
) )
// GenerateCodesRequest 生成兑换码请求 // GenerateCodesRequest 生成兑换码请求
@@ -53,7 +51,7 @@ type RedeemService struct {
redeemRepo ports.RedeemCodeRepository redeemRepo ports.RedeemCodeRepository
userRepo ports.UserRepository userRepo ports.UserRepository
subscriptionService *SubscriptionService subscriptionService *SubscriptionService
rdb *redis.Client cache ports.RedeemCache
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
} }
@@ -62,14 +60,14 @@ func NewRedeemService(
redeemRepo ports.RedeemCodeRepository, redeemRepo ports.RedeemCodeRepository,
userRepo ports.UserRepository, userRepo ports.UserRepository,
subscriptionService *SubscriptionService, subscriptionService *SubscriptionService,
rdb *redis.Client, cache ports.RedeemCache,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
) *RedeemService { ) *RedeemService {
return &RedeemService{ return &RedeemService{
redeemRepo: redeemRepo, redeemRepo: redeemRepo,
userRepo: userRepo, userRepo: userRepo,
subscriptionService: subscriptionService, subscriptionService: subscriptionService,
rdb: rdb, cache: cache,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
} }
} }
@@ -140,13 +138,11 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
// checkRedeemRateLimit 检查用户兑换错误次数是否超限 // checkRedeemRateLimit 检查用户兑换错误次数是否超限
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error { func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil { if s.cache == nil {
return nil return nil
} }
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
count, err := s.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) { if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作 // Redis 出错时不阻止用户操作
return nil return nil
@@ -161,27 +157,21 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64)
// incrementRedeemErrorCount 增加用户兑换错误计数 // incrementRedeemErrorCount 增加用户兑换错误计数
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) { func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil { if s.cache == nil {
return return
} }
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) _ = s.cache.IncrementRedeemAttemptCount(ctx, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, _ = pipe.Exec(ctx)
} }
// acquireRedeemLock 尝试获取兑换码的分布式锁 // acquireRedeemLock 尝试获取兑换码的分布式锁
// 返回 true 表示获取成功false 表示锁已被占用 // 返回 true 表示获取成功false 表示锁已被占用
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool { func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
if s.rdb == nil { if s.cache == nil {
return true // 无 Redis 时降级为不加锁 return true // 无 Redis 时降级为不加锁
} }
key := redeemLockKeyPrefix + code ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration)
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
if err != nil { if err != nil {
// Redis 出错时不阻止操作,依赖数据库层面的状态检查 // Redis 出错时不阻止操作,依赖数据库层面的状态检查
return true return true
@@ -191,12 +181,11 @@ func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool
// releaseRedeemLock 释放兑换码的分布式锁 // releaseRedeemLock 释放兑换码的分布式锁
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) { func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
if s.rdb == nil { if s.cache == nil {
return return
} }
key := redeemLockKeyPrefix + code _ = s.cache.ReleaseRedeemLock(ctx, code)
s.rdb.Del(ctx, key)
} }
// Redeem 使用兑换码 // Redeem 使用兑换码

View File

@@ -18,7 +18,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/redis/go-redis/v9" "sub2api/internal/service/ports"
) )
const ( const (
@@ -36,15 +36,15 @@ const (
// UpdateService handles software updates // UpdateService handles software updates
type UpdateService struct { type UpdateService struct {
rdb *redis.Client cache ports.UpdateCache
currentVersion string currentVersion string
buildType string // "source" for manual builds, "release" for CI builds buildType string // "source" for manual builds, "release" for CI builds
} }
// NewUpdateService creates a new UpdateService // NewUpdateService creates a new UpdateService
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService { func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService {
return &UpdateService{ return &UpdateService{
rdb: rdb, cache: cache,
currentVersion: version, currentVersion: version,
buildType: buildType, buildType: buildType,
} }
@@ -533,7 +533,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
} }
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) { func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
data, err := s.rdb.Get(ctx, updateCacheKey).Result() data, err := s.cache.GetUpdateInfo(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -573,7 +573,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
} }
data, _ := json.Marshal(cacheData) data, _ := json.Marshal(cacheData)
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second) s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
} }
// compareVersions compares two semantic versions // compareVersions compares two semantic versions