feat(idempotency): 为关键写接口接入幂等并完善并发容错

This commit is contained in:
yangjianbo
2026-02-23 12:45:37 +08:00
parent 3b6584cc8d
commit 5fa45f3b8c
40 changed files with 4383 additions and 223 deletions

View File

@@ -74,6 +74,7 @@ func provideCleanup(
accountExpiry *service.AccountExpiryService, accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService, subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService, usageCleanup *service.UsageCleanupService,
idempotencyCleanup *service.IdempotencyCleanupService,
pricing *service.PricingService, pricing *service.PricingService,
emailQueue *service.EmailQueueService, emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService, billingCache *service.BillingCacheService,
@@ -147,6 +148,12 @@ func provideCleanup(
} }
return nil return nil
}}, }},
{"IdempotencyCleanupService", func() error {
if idempotencyCleanup != nil {
idempotencyCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error { {"TokenRefreshService", func() error {
tokenRefresh.Stop() tokenRefresh.Stop()
return nil return nil

View File

@@ -168,7 +168,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
serviceBuildInfo := provideServiceBuildInfo(buildInfo) serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService) idempotencyRepository := repository.NewIdempotencyRepository(client, db)
systemOperationLockService := service.ProvideSystemOperationLockService(idempotencyRepository, configConfig)
systemHandler := handler.ProvideSystemHandler(updateService, systemOperationLockService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db) usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig) usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
@@ -191,7 +193,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService) totpHandler := handler.NewTotpHandler(totpService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -206,7 +210,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
@@ -243,6 +247,7 @@ func provideCleanup(
accountExpiry *service.AccountExpiryService, accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService, subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService, usageCleanup *service.UsageCleanupService,
idempotencyCleanup *service.IdempotencyCleanupService,
pricing *service.PricingService, pricing *service.PricingService,
emailQueue *service.EmailQueueService, emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService, billingCache *service.BillingCacheService,
@@ -315,6 +320,12 @@ func provideCleanup(
} }
return nil return nil
}}, }},
{"IdempotencyCleanupService", func() error {
if idempotencyCleanup != nil {
idempotencyCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error { {"TokenRefreshService", func() error {
tokenRefresh.Stop() tokenRefresh.Stop()
return nil return nil

View File

@@ -0,0 +1,50 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// IdempotencyRecord is the ent schema for the idempotency request record
// table: one row per keyed write request, so retries can be detected and
// replayed instead of re-executed.
type IdempotencyRecord struct {
	ent.Schema
}

// Annotations binds the schema to the "idempotency_records" SQL table.
func (IdempotencyRecord) Annotations() []schema.Annotation {
	return []schema.Annotation{
		entsql.Annotation{Table: "idempotency_records"},
	}
}

// Mixin pulls in the shared time mixin (presumably created_at/updated_at —
// confirm in the mixins package).
func (IdempotencyRecord) Mixin() []ent.Mixin {
	return []ent.Mixin{
		mixins.TimeMixin{},
	}
}

// Fields declares the columns of an idempotency record.
func (IdempotencyRecord) Fields() []ent.Field {
	return []ent.Field{
		// Logical operation scope, e.g. "admin.accounts.create".
		field.String("scope").MaxLen(128),
		// Hash of the client-supplied Idempotency-Key (raw key is not stored).
		field.String("idempotency_key_hash").MaxLen(64),
		// Fingerprint of the request payload, used to detect a key being
		// reused with a different body.
		field.String("request_fingerprint").MaxLen(64),
		// Lifecycle status, e.g. processing / succeeded / failed_retryable.
		field.String("status").MaxLen(32),
		// Persisted response (HTTP status + body) for replaying successes.
		field.Int("response_status").Optional().Nillable(),
		field.String("response_body").Optional().Nillable(),
		// Reason recorded for retryable failures.
		field.String("error_reason").MaxLen(128).Optional().Nillable(),
		// Lock deadline while a request is in the processing state.
		field.Time("locked_until").Optional().Nillable(),
		// Absolute expiry after which the record may be garbage-collected.
		field.Time("expires_at"),
	}
}

// Indexes declares lookup and cleanup indexes. The unique
// (scope, idempotency_key_hash) index is the database-level guarantee that
// makes concurrent inserts for the same key race-safe.
func (IdempotencyRecord) Indexes() []ent.Index {
	return []ent.Index{
		index.Fields("scope", "idempotency_key_hash").Unique(),
		index.Fields("expires_at"),
		index.Fields("status", "locked_until"),
	}
}

View File

@@ -74,6 +74,7 @@ type Config struct {
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"` Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"` Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
} }
type LogConfig struct { type LogConfig struct {
@@ -137,6 +138,25 @@ type UpdateConfig struct {
ProxyURL string `mapstructure:"proxy_url"` ProxyURL string `mapstructure:"proxy_url"`
} }
// IdempotencyConfig controls the idempotency guarantees applied to key
// write endpoints.
type IdempotencyConfig struct {
	// ObserveOnly enables the observation phase: requests that do not carry
	// an Idempotency-Key header are still allowed through.
	ObserveOnly bool `mapstructure:"observe_only"`
	// DefaultTTLSeconds is the default idempotency-record TTL for key write
	// endpoints.
	DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"`
	// SystemOperationTTLSeconds is the record TTL for system operation endpoints.
	SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"`
	// ProcessingTimeoutSeconds is the lock timeout (seconds) for records in
	// the processing state.
	ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"`
	// FailedRetryBackoffSeconds is the backoff window (seconds) after a
	// retryable failure.
	FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"`
	// MaxStoredResponseLen is the maximum persisted response body length (bytes).
	MaxStoredResponseLen int `mapstructure:"max_stored_response_len"`
	// CleanupIntervalSeconds is the sweep interval (seconds) for expired records.
	CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
	// CleanupBatchSize is the maximum number of records removed per sweep.
	CleanupBatchSize int `mapstructure:"cleanup_batch_size"`
}
type LinuxDoConnectConfig struct { type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"` Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"` ClientID string `mapstructure:"client_id"`
@@ -1117,6 +1137,16 @@ func setDefaults() {
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10) viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800) viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
// Idempotency
viper.SetDefault("idempotency.observe_only", true)
viper.SetDefault("idempotency.default_ttl_seconds", 86400)
viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600)
viper.SetDefault("idempotency.processing_timeout_seconds", 30)
viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5)
viper.SetDefault("idempotency.max_stored_response_len", 64*1024)
viper.SetDefault("idempotency.cleanup_interval_seconds", 60)
viper.SetDefault("idempotency.cleanup_batch_size", 500)
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", true) viper.SetDefault("gateway.log_upstream_error_body", true)
@@ -1560,6 +1590,27 @@ func (c *Config) Validate() error {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative") return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
} }
} }
if c.Idempotency.DefaultTTLSeconds <= 0 {
return fmt.Errorf("idempotency.default_ttl_seconds must be positive")
}
if c.Idempotency.SystemOperationTTLSeconds <= 0 {
return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive")
}
if c.Idempotency.ProcessingTimeoutSeconds <= 0 {
return fmt.Errorf("idempotency.processing_timeout_seconds must be positive")
}
if c.Idempotency.FailedRetryBackoffSeconds <= 0 {
return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive")
}
if c.Idempotency.MaxStoredResponseLen <= 0 {
return fmt.Errorf("idempotency.max_stored_response_len must be positive")
}
if c.Idempotency.CleanupIntervalSeconds <= 0 {
return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive")
}
if c.Idempotency.CleanupBatchSize <= 0 {
return fmt.Errorf("idempotency.cleanup_batch_size must be positive")
}
if c.Gateway.MaxBodySize <= 0 { if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive") return fmt.Errorf("gateway.max_body_size must be positive")
} }

View File

@@ -75,6 +75,42 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
} }
} }
// TestLoadDefaultIdempotencyConfig verifies the shipped idempotency defaults.
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
	resetViperWithJWTSecret(t)

	loaded, loadErr := Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	idem := loaded.Idempotency
	if !idem.ObserveOnly {
		t.Fatalf("Idempotency.ObserveOnly = false, want true")
	}
	if idem.DefaultTTLSeconds != 86400 {
		t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", idem.DefaultTTLSeconds)
	}
	if idem.SystemOperationTTLSeconds != 3600 {
		t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", idem.SystemOperationTTLSeconds)
	}
}
// TestLoadIdempotencyConfigFromEnv verifies that environment variables
// override the idempotency defaults.
func TestLoadIdempotencyConfigFromEnv(t *testing.T) {
	resetViperWithJWTSecret(t)
	t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false")
	t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600")

	loaded, loadErr := Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	idem := loaded.Idempotency
	if idem.ObserveOnly {
		t.Fatalf("Idempotency.ObserveOnly = true, want false")
	}
	if idem.DefaultTTLSeconds != 600 {
		t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", idem.DefaultTTLSeconds)
	}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) { func TestLoadSchedulingConfigFromEnv(t *testing.T) {
resetViperWithJWTSecret(t) resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")

View File

@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
return return
} }
dataPayload := req.Data if err := validateDataHeader(req.Data); err != nil {
if err := validateDataHeader(dataPayload); err != nil {
response.BadRequest(c, err.Error()) response.BadRequest(c, err.Error())
return return
} }
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
return h.importData(ctx, req)
})
}
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
skipDefaultGroupBind := true skipDefaultGroupBind := true
if req.SkipDefaultGroupBind != nil { if req.SkipDefaultGroupBind != nil {
skipDefaultGroupBind = *req.SkipDefaultGroupBind skipDefaultGroupBind = *req.SkipDefaultGroupBind
} }
dataPayload := req.Data
result := DataImportResult{} result := DataImportResult{}
existingProxies, err := h.listAllProxies(c.Request.Context())
existingProxies, err := h.listAllProxies(ctx)
if err != nil { if err != nil {
response.ErrorFrom(c, err) return result, err
return
} }
proxyKeyToID := make(map[string]int64, len(existingProxies)) proxyKeyToID := make(map[string]int64, len(existingProxies))
@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
proxyKeyToID[key] = existingID proxyKeyToID[key] = existingID
result.ProxyReused++ result.ProxyReused++
if normalizedStatus != "" { if normalizedStatus != "" {
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus { if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{ _, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
Status: normalizedStatus, Status: normalizedStatus,
}) })
} }
@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
continue continue
} }
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name), Name: defaultProxyName(item.Name),
Protocol: item.Protocol, Protocol: item.Protocol,
Host: item.Host, Host: item.Host,
@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
Username: item.Username, Username: item.Username,
Password: item.Password, Password: item.Password,
}) })
if err != nil { if createErr != nil {
result.ProxyFailed++ result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{ result.Errors = append(result.Errors, DataImportError{
Kind: "proxy", Kind: "proxy",
Name: item.Name, Name: item.Name,
ProxyKey: key, ProxyKey: key,
Message: err.Error(), Message: createErr.Error(),
}) })
continue continue
} }
@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.ProxyCreated++ result.ProxyCreated++
if normalizedStatus != "" && normalizedStatus != created.Status { if normalizedStatus != "" && normalizedStatus != created.Status {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{ _, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
Status: normalizedStatus, Status: normalizedStatus,
}) })
} }
@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
SkipDefaultGroupBind: skipDefaultGroupBind, SkipDefaultGroupBind: skipDefaultGroupBind,
} }
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil { if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
result.AccountFailed++ result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{ result.Errors = append(result.Errors, DataImportError{
Kind: "account", Kind: "account",
@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.AccountCreated++ result.AccountCreated++
} }
response.Success(c, result) return result, nil
} }
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) { func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {

View File

@@ -405,21 +405,27 @@ func (h *AccountHandler) Create(c *gin.Context) {
// 确定是否跳过混合渠道检查 // 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
Name: req.Name, account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Notes: req.Notes, Name: req.Name,
Platform: req.Platform, Notes: req.Notes,
Type: req.Type, Platform: req.Platform,
Credentials: req.Credentials, Type: req.Type,
Extra: req.Extra, Credentials: req.Credentials,
ProxyID: req.ProxyID, Extra: req.Extra,
Concurrency: req.Concurrency, ProxyID: req.ProxyID,
Priority: req.Priority, Concurrency: req.Concurrency,
RateMultiplier: req.RateMultiplier, Priority: req.Priority,
GroupIDs: req.GroupIDs, RateMultiplier: req.RateMultiplier,
ExpiresAt: req.ExpiresAt, GroupIDs: req.GroupIDs,
AutoPauseOnExpired: req.AutoPauseOnExpired, ExpiresAt: req.ExpiresAt,
SkipMixedChannelCheck: skipCheck, AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if execErr != nil {
return nil, execErr
}
return h.buildAccountResponseWithRuntime(ctx, account), nil
}) })
if err != nil { if err != nil {
// 检查是否为混合渠道错误 // 检查是否为混合渠道错误
@@ -440,11 +446,17 @@ func (h *AccountHandler) Create(c *gin.Context) {
return return
} }
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
} }
// Update handles updating an account // Update handles updating an account
@@ -838,61 +850,62 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return return
} }
ctx := c.Request.Context() executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
success := 0 success := 0
failed := 0 failed := 0
results := make([]gin.H, 0, len(req.Accounts)) results := make([]gin.H, 0, len(req.Accounts))
for _, item := range req.Accounts { for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 { if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++ failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": "rate_multiplier must be >= 0",
})
continue
}
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{ results = append(results, gin.H{
"name": item.Name, "name": item.Name,
"success": false, "id": account.ID,
"error": "rate_multiplier must be >= 0", "success": true,
}) })
continue
} }
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk return gin.H{
"success": success,
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ "failed": failed,
Name: item.Name, "results": results,
Notes: item.Notes, }, nil
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"name": item.Name,
"id": account.ID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
}) })
} }

View File

@@ -0,0 +1,115 @@
package admin
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// idempotencyStoreUnavailableMode selects how a handler reacts when the
// idempotency store cannot be reached.
type idempotencyStoreUnavailableMode int

const (
	// idempotencyStoreUnavailableFailClose rejects the request when the
	// store is unavailable, so the business logic never runs.
	idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
	// idempotencyStoreUnavailableFailOpen runs the business logic anyway
	// and flags the response as degraded.
	idempotencyStoreUnavailableFailOpen
)
// executeAdminIdempotent runs execute under the process-wide idempotency
// coordinator, deriving the actor scope from the authenticated admin subject.
// When no coordinator is configured, the business logic runs directly with no
// idempotency tracking.
func executeAdminIdempotent(
	c *gin.Context,
	scope string,
	payload any,
	ttl time.Duration,
	execute func(context.Context) (any, error),
) (*service.IdempotencyExecuteResult, error) {
	coord := service.DefaultIdempotencyCoordinator()
	if coord == nil {
		// No coordinator wired in: execute without idempotency bookkeeping.
		data, execErr := execute(c.Request.Context())
		if execErr != nil {
			return nil, execErr
		}
		return &service.IdempotencyExecuteResult{Data: data}, nil
	}

	actor := "admin:0"
	if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
		actor = "admin:" + strconv.FormatInt(subject.UserID, 10)
	}

	opts := service.IdempotencyExecuteOptions{
		Scope:          scope,
		ActorScope:     actor,
		Method:         c.Request.Method,
		Route:          c.FullPath(),
		IdempotencyKey: c.GetHeader("Idempotency-Key"),
		Payload:        payload,
		RequireKey:     true,
		TTL:            ttl,
	}
	return coord.Execute(c.Request.Context(), opts, execute)
}
// executeAdminIdempotentJSON runs execute idempotently and writes the JSON
// response, rejecting the request (fail-close) when the idempotency store is
// unavailable.
func executeAdminIdempotentJSON(
	c *gin.Context,
	scope string,
	payload any,
	ttl time.Duration,
	execute func(context.Context) (any, error),
) {
	executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
}

// executeAdminIdempotentJSONFailOpenOnStoreUnavailable is like
// executeAdminIdempotentJSON but falls back to executing the business logic
// directly (fail-open) when the idempotency store is unavailable.
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
	c *gin.Context,
	scope string,
	payload any,
	ttl time.Duration,
	execute func(context.Context) (any, error),
) {
	executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
}
// executeAdminIdempotentJSONWithMode runs execute under idempotency control
// and writes the JSON response. mode decides what happens when the
// idempotency store is unreachable: fail-close returns the error to the
// client without running the business logic, fail-open runs it once and
// marks the response as degraded.
func executeAdminIdempotentJSONWithMode(
	c *gin.Context,
	scope string,
	payload any,
	ttl time.Duration,
	mode idempotencyStoreUnavailableMode,
	execute func(context.Context) (any, error),
) {
	result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
	if err != nil {
		if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
			strategy := "fail_close"
			if mode == idempotencyStoreUnavailableFailOpen {
				strategy = "fail_open"
			}
			service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
			logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
			if mode == idempotencyStoreUnavailableFailOpen {
				data, fallbackErr := execute(c.Request.Context())
				if fallbackErr != nil {
					response.ErrorFrom(c, fallbackErr)
					return
				}
				c.Header("X-Idempotency-Degraded", "store-unavailable")
				response.Success(c, data)
				return
			}
		}
		// Surface the coordinator's suggested retry delay (e.g. while a
		// concurrent request still holds the processing lock).
		if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
			c.Header("Retry-After", strconv.Itoa(retryAfter))
		}
		response.ErrorFrom(c, err)
		return
	}
	// Fix: the original checked result != nil before reading Replayed but then
	// dereferenced result.Data unconditionally, panicking on a nil result with
	// a nil error. Guard the whole success path instead.
	if result == nil {
		response.Success(c, nil)
		return
	}
	if result.Replayed {
		c.Header("X-Idempotency-Replayed", "true")
	}
	response.Success(c, result.Data)
}

View File

@@ -0,0 +1,285 @@
package admin
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// storeUnavailableRepoStub simulates an idempotency repository whose backing
// store is completely unreachable: every method fails with a generic error so
// tests can exercise the fail-close / fail-open degradation paths.
type storeUnavailableRepoStub struct{}

func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
	return false, errors.New("store unavailable")
}

func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
	return nil, errors.New("store unavailable")
}

func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, errors.New("store unavailable")
}

func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, errors.New("store unavailable")
}

func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return errors.New("store unavailable")
}

func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return errors.New("store unavailable")
}

func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
	return 0, errors.New("store unavailable")
}
// TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable verifies that the
// default (fail-close) strategy blocks the business logic entirely when the
// idempotency store cannot be reached.
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
	gin.SetMode(gin.TestMode)
	service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
	t.Cleanup(func() {
		service.SetDefaultIdempotencyCoordinator(nil)
	})

	var calls int
	engine := gin.New()
	engine.POST("/idempotent", func(c *gin.Context) {
		executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
			calls++
			return gin.H{"ok": true}, nil
		})
	})

	req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Idempotency-Key", "test-key-1")
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, req)

	require.Equal(t, http.StatusServiceUnavailable, rec.Code)
	require.Equal(t, 0, calls, "fail-close should block business execution when idempotency store is unavailable")
}
// TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable verifies that the
// fail-open strategy still runs the business logic once and marks the
// response as degraded when the idempotency store is unreachable.
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
	gin.SetMode(gin.TestMode)
	service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
	t.Cleanup(func() {
		service.SetDefaultIdempotencyCoordinator(nil)
	})

	var calls int
	engine := gin.New()
	engine.POST("/idempotent", func(c *gin.Context) {
		executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
			calls++
			return gin.H{"ok": true}, nil
		})
	})

	req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Idempotency-Key", "test-key-2")
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, req)

	require.Equal(t, http.StatusOK, rec.Code)
	require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
	require.Equal(t, 1, calls, "fail-open strategy should allow semantic idempotent path to continue")
}
// memoryIdempotencyRepoStub is a mutex-guarded in-memory implementation of
// the idempotency repository, keyed by scope + key hash, used to drive the
// coordinator in concurrency tests.
type memoryIdempotencyRepoStub struct {
	mu     sync.Mutex
	nextID int64
	data   map[string]*service.IdempotencyRecord
}

// newMemoryIdempotencyRepoStub returns an empty stub with IDs starting at 1.
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
	return &memoryIdempotencyRepoStub{
		nextID: 1,
		data:   make(map[string]*service.IdempotencyRecord),
	}
}
// key builds the map key for a (scope, keyHash) pair.
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
	return scope + "|" + keyHash
}

// clone deep-copies a record (including its pointer fields) so callers can
// never mutate the stub's internal state through a returned record.
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
	if in == nil {
		return nil
	}
	cp := *in
	if t := in.LockedUntil; t != nil {
		v := *t
		cp.LockedUntil = &v
	}
	if s := in.ResponseBody; s != nil {
		v := *s
		cp.ResponseBody = &v
	}
	if n := in.ResponseStatus; n != nil {
		v := *n
		cp.ResponseStatus = &v
	}
	if e := in.ErrorReason; e != nil {
		v := *e
		cp.ErrorReason = &v
	}
	return &cp
}
// CreateProcessing inserts a fresh record unless one already exists for the
// same (scope, key hash); it reports whether the insert happened and writes
// the assigned ID back into the caller's record.
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	k := r.key(record.Scope, record.IdempotencyKeyHash)
	if _, exists := r.data[k]; exists {
		return false, nil
	}

	stored := r.clone(record)
	stored.ID = r.nextID
	r.nextID++
	r.data[k] = stored
	record.ID = stored.ID
	return true, nil
}

// GetByScopeAndKeyHash returns a copy of the stored record, or nil when absent.
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.clone(r.data[r.key(scope, keyHash)]), nil
}
// TryReclaim moves the record with the given ID from fromStatus back into
// processing, but only when its lock has expired; the lock and expiry are
// refreshed on success.
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	var target *service.IdempotencyRecord
	for _, rec := range r.data {
		if rec.ID == id {
			target = rec
			break
		}
	}
	switch {
	case target == nil:
		return false, nil
	case target.Status != fromStatus:
		return false, nil
	case target.LockedUntil != nil && target.LockedUntil.After(now):
		// Lock still held by another in-flight request.
		return false, nil
	}

	target.Status = service.IdempotencyStatusProcessing
	target.LockedUntil = &newLockedUntil
	target.ExpiresAt = newExpiresAt
	target.ErrorReason = nil
	return true, nil
}
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
return false, nil
}
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusSucceeded
rec.LockedUntil = nil
rec.ExpiresAt = expiresAt
rec.ResponseStatus = &responseStatus
rec.ResponseBody = &responseBody
rec.ErrorReason = nil
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
return 0, nil
}
// TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect verifies that
// two concurrent requests sharing one Idempotency-Key run the side effect
// exactly once, and that a later retry replays the stored response.
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
	gin.SetMode(gin.TestMode)
	stubRepo := newMemoryIdempotencyRepoStub()
	coordCfg := service.DefaultIdempotencyConfig()
	coordCfg.ProcessingTimeout = 2 * time.Second
	service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(stubRepo, coordCfg))
	t.Cleanup(func() {
		service.SetDefaultIdempotencyCoordinator(nil)
	})
	var sideEffects atomic.Int32
	engine := gin.New()
	engine.POST("/idempotent", func(c *gin.Context) {
		executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
			sideEffects.Add(1)
			// Hold the lock long enough for the second request to collide.
			time.Sleep(120 * time.Millisecond)
			return gin.H{"ok": true}, nil
		})
	})
	doRequest := func() (int, http.Header) {
		req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Idempotency-Key", "same-key")
		rec := httptest.NewRecorder()
		engine.ServeHTTP(rec, req)
		return rec.Code, rec.Header()
	}
	// Fire two requests at once; each writes its status into its own slot.
	statuses := make([]int, 2)
	var wg sync.WaitGroup
	for i := range statuses {
		wg.Add(1)
		go func(slot int) {
			defer wg.Done()
			statuses[slot], _ = doRequest()
		}(i)
	}
	wg.Wait()
	allowed := []int{http.StatusOK, http.StatusConflict}
	require.Contains(t, allowed, statuses[0])
	require.Contains(t, allowed, statuses[1])
	require.Equal(t, int32(1), sideEffects.Load(), "same idempotency key should execute side-effect only once")
	// A sequential retry must replay the stored success without re-executing.
	replayStatus, replayHeaders := doRequest()
	require.Equal(t, http.StatusOK, replayStatus)
	require.Equal(t, "true", replayHeaders.Get("X-Idempotency-Replayed"))
	require.Equal(t, int32(1), sideEffects.Load())
}

View File

@@ -1,6 +1,7 @@
package admin package admin
import ( import (
"context"
"strconv" "strconv"
"strings" "strings"
@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) {
return return
} }
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
Name: strings.TrimSpace(req.Name), proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Protocol: strings.TrimSpace(req.Protocol), Name: strings.TrimSpace(req.Name),
Host: strings.TrimSpace(req.Host), Protocol: strings.TrimSpace(req.Protocol),
Port: req.Port, Host: strings.TrimSpace(req.Host),
Username: strings.TrimSpace(req.Username), Port: req.Port,
Password: strings.TrimSpace(req.Password), Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
return nil, err
}
return dto.ProxyFromService(proxy), nil
}) })
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.ProxyFromService(proxy))
} }
// Update handles updating a proxy // Update handles updating a proxy

View File

@@ -2,6 +2,7 @@ package admin
import ( import (
"bytes" "bytes"
"context"
"encoding/csv" "encoding/csv"
"fmt" "fmt"
"strconv" "strconv"
@@ -88,23 +89,24 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return return
} }
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{ executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
Count: req.Count, codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Type: req.Type, Count: req.Count,
Value: req.Value, Type: req.Type,
GroupID: req.GroupID, Value: req.Value,
ValidityDays: req.ValidityDays, GroupID: req.GroupID,
}) ValidityDays: req.ValidityDays,
if err != nil { })
response.ErrorFrom(c, err) if execErr != nil {
return return nil, execErr
} }
out := make([]dto.AdminRedeemCode, 0, len(codes)) out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes { for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
} }
response.Success(c, out) return out, nil
})
} }
// Delete handles deleting a redeem code // Delete handles deleting a redeem code

View File

@@ -1,6 +1,7 @@
package admin package admin
import ( import (
"context"
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return return
} }
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) idempotencyPayload := struct {
if err != nil { SubscriptionID int64 `json:"subscription_id"`
response.ErrorFrom(c, err) Body AdjustSubscriptionRequest `json:"body"`
return }{
SubscriptionID: subscriptionID,
Body: req,
} }
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
if execErr != nil {
return nil, execErr
}
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
})
} }
// Revoke handles revoking a subscription // Revoke handles revoking a subscription

View File

@@ -1,11 +1,15 @@
package admin package admin
import ( import (
"context"
"net/http" "net/http"
"strconv"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil" "github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -14,12 +18,14 @@ import (
// SystemHandler handles system-related operations // SystemHandler handles system-related operations
type SystemHandler struct { type SystemHandler struct {
updateSvc *service.UpdateService updateSvc *service.UpdateService
lockSvc *service.SystemOperationLockService
} }
// NewSystemHandler creates a new SystemHandler // NewSystemHandler creates a new SystemHandler
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler { func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
return &SystemHandler{ return &SystemHandler{
updateSvc: updateSvc, updateSvc: updateSvc,
lockSvc: lockSvc,
} }
} }
@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) {
// PerformUpdate downloads and applies the update // PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update // POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) { func (h *SystemHandler) PerformUpdate(c *gin.Context) {
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil { operationID := buildSystemOperationID(c, "update")
response.Error(c, http.StatusInternalServerError, err.Error()) payload := gin.H{"operation_id": operationID}
return executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
} lock, release, err := h.acquireSystemLock(ctx, operationID)
response.Success(c, gin.H{ if err != nil {
"message": "Update completed. Please restart the service.", return nil, err
"need_restart": true, }
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
releaseReason = "SYSTEM_UPDATE_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
}) })
} }
// Rollback restores the previous version // Rollback restores the previous version
// POST /api/v1/admin/system/rollback // POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) { func (h *SystemHandler) Rollback(c *gin.Context) {
if err := h.updateSvc.Rollback(); err != nil { operationID := buildSystemOperationID(c, "rollback")
response.Error(c, http.StatusInternalServerError, err.Error()) payload := gin.H{"operation_id": operationID}
return executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
} lock, release, err := h.acquireSystemLock(ctx, operationID)
response.Success(c, gin.H{ if err != nil {
"message": "Rollback completed. Please restart the service.", return nil, err
"need_restart": true, }
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.Rollback(); err != nil {
releaseReason = "SYSTEM_ROLLBACK_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
}) })
} }
// RestartService restarts the systemd service // RestartService restarts the systemd service
// POST /api/v1/admin/system/restart // POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) { func (h *SystemHandler) RestartService(c *gin.Context) {
// Schedule service restart in background after sending response operationID := buildSystemOperationID(c, "restart")
// This ensures the client receives the success response before the service restarts payload := gin.H{"operation_id": operationID}
go func() { executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
// Wait a moment to ensure the response is sent lock, release, err := h.acquireSystemLock(ctx, operationID)
time.Sleep(500 * time.Millisecond) if err != nil {
sysutil.RestartServiceAsync() return nil, err
}() }
succeeded := false
defer func() {
release("", succeeded)
}()
response.Success(c, gin.H{ // Schedule service restart in background after sending response
"message": "Service restart initiated", // This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
succeeded = true
return gin.H{
"message": "Service restart initiated",
"operation_id": lock.OperationID(),
}, nil
}) })
} }
func (h *SystemHandler) acquireSystemLock(
ctx context.Context,
operationID string,
) (*service.SystemOperationLock, func(string, bool), error) {
if h.lockSvc == nil {
return nil, nil, service.ErrIdempotencyStoreUnavail
}
lock, err := h.lockSvc.Acquire(ctx, operationID)
if err != nil {
return nil, nil, err
}
release := func(reason string, succeeded bool) {
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
}
return lock, release, nil
}
func buildSystemOperationID(c *gin.Context, operation string) string {
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
if key == "" {
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
hash := service.HashIdempotencyKey(seed)
if len(hash) > 24 {
hash = hash[:24]
}
return "sysop-" + hash
}

View File

@@ -1,6 +1,7 @@
package admin package admin
import ( import (
"context"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@@ -472,29 +473,36 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
billingType = *filters.BillingType billingType = *filters.BillingType
} }
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", idempotencyPayload := struct {
subject.UserID, OperatorID int64 `json:"operator_id"`
filters.StartTime.Format(time.RFC3339), Body CreateUsageCleanupTaskRequest `json:"body"`
filters.EndTime.Format(time.RFC3339), }{
userID, OperatorID: subject.UserID,
apiKeyID, Body: req,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
if err != nil {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
response.ErrorFrom(c, err)
return
} }
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID)
response.Success(c, dto.UsageCleanupTaskFromService(task)) if err != nil {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
return nil, err
}
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
return dto.UsageCleanupTaskFromService(task), nil
})
} }
// CancelCleanupTask handles canceling a usage cleanup task // CancelCleanupTask handles canceling a usage cleanup task

View File

@@ -1,6 +1,7 @@
package admin package admin
import ( import (
"context"
"strconv" "strconv"
"strings" "strings"
@@ -257,13 +258,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return return
} }
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes) idempotencyPayload := struct {
if err != nil { UserID int64 `json:"user_id"`
response.ErrorFrom(c, err) Body UpdateBalanceRequest `json:"body"`
return }{
UserID: userID,
Body: req,
} }
executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
response.Success(c, dto.UserFromServiceAdmin(user)) user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes)
if execErr != nil {
return nil, execErr
}
return dto.UserFromServiceAdmin(user), nil
})
} }
// GetUserAPIKeys handles getting user's API keys // GetUserAPIKeys handles getting user's API keys

View File

@@ -2,6 +2,7 @@
package handler package handler
import ( import (
"context"
"strconv" "strconv"
"time" "time"
@@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
if req.Quota != nil { if req.Quota != nil {
svcReq.Quota = *req.Quota svcReq.Quota = *req.Quota
} }
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.APIKeyFromService(key)) executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
if err != nil {
return nil, err
}
return dto.APIKeyFromService(key), nil
})
} }
// Update handles updating an API key // Update handles updating an API key

View File

@@ -2,6 +2,7 @@
package dto package dto
import ( import (
"strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -542,11 +543,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
for i := range r.Subscriptions { for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i])) subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
} }
statuses := make(map[string]string, len(r.Statuses))
for userID, status := range r.Statuses {
statuses[strconv.FormatInt(userID, 10)] = status
}
return &BulkAssignResult{ return &BulkAssignResult{
SuccessCount: r.SuccessCount, SuccessCount: r.SuccessCount,
CreatedCount: r.CreatedCount,
ReusedCount: r.ReusedCount,
FailedCount: r.FailedCount, FailedCount: r.FailedCount,
Subscriptions: subs, Subscriptions: subs,
Errors: r.Errors, Errors: r.Errors,
Statuses: statuses,
} }
} }

View File

@@ -395,9 +395,12 @@ type AdminUserSubscription struct {
type BulkAssignResult struct { type BulkAssignResult struct {
SuccessCount int `json:"success_count"` SuccessCount int `json:"success_count"`
CreatedCount int `json:"created_count"`
ReusedCount int `json:"reused_count"`
FailedCount int `json:"failed_count"` FailedCount int `json:"failed_count"`
Subscriptions []AdminUserSubscription `json:"subscriptions"` Subscriptions []AdminUserSubscription `json:"subscriptions"`
Errors []string `json:"errors"` Errors []string `json:"errors"`
Statuses map[string]string `json:"statuses,omitempty"`
} }
// PromoCode 注册优惠码 // PromoCode 注册优惠码

View File

@@ -0,0 +1,65 @@
package handler
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// executeUserIdempotentJSON runs a user-facing write operation under the
// process-wide idempotency coordinator and renders its JSON result.
//
// When no coordinator is configured the operation executes directly
// (idempotency disabled). Otherwise the request is keyed by the caller's
// Idempotency-Key header (required), scoped to the authenticated user, and
// the coordinator either executes the operation, rejects an in-flight
// duplicate, or replays a previously stored response — replays are signalled
// via the X-Idempotency-Replayed header.
//
// Parameters:
//   - scope: logical operation name namespacing idempotency records.
//   - payload: request payload folded into the request fingerprint.
//   - ttl: how long a completed record remains replayable.
//   - execute: the side-effecting operation; runs at most once per key.
func executeUserIdempotentJSON(
	c *gin.Context,
	scope string,
	payload any,
	ttl time.Duration,
	execute func(context.Context) (any, error),
) {
	coordinator := service.DefaultIdempotencyCoordinator()
	if coordinator == nil {
		// No coordinator configured: fall back to plain execution.
		data, err := execute(c.Request.Context())
		if err != nil {
			response.ErrorFrom(c, err)
			return
		}
		response.Success(c, data)
		return
	}
	// Scope records to the authenticated user so keys cannot collide across
	// accounts; "user:0" covers the (unexpected) unauthenticated case.
	actorScope := "user:0"
	if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
		actorScope = "user:" + strconv.FormatInt(subject.UserID, 10)
	}
	result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
		Scope:          scope,
		ActorScope:     actorScope,
		Method:         c.Request.Method,
		Route:          c.FullPath(),
		IdempotencyKey: c.GetHeader("Idempotency-Key"),
		Payload:        payload,
		RequireKey:     true,
		TTL:            ttl,
	}, execute)
	if err != nil {
		if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
			// Fail-close: surface store outages rather than risk duplicate writes.
			service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close")
			logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope)
		}
		if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
			c.Header("Retry-After", strconv.Itoa(retryAfter))
		}
		response.ErrorFrom(c, err)
		return
	}
	// Fix: the original nil-checked result before reading Replayed but then
	// dereferenced result.Data unconditionally, which would panic on a nil
	// result. Guard the dereference to keep the nil handling consistent.
	if result == nil {
		response.Success(c, nil)
		return
	}
	if result.Replayed {
		c.Header("X-Idempotency-Replayed", "true")
	}
	response.Success(c, result.Data)
}

View File

@@ -0,0 +1,285 @@
package handler
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// userStoreUnavailableRepoStub simulates a persistence layer that is down:
// every repository method fails with a fresh "store unavailable" error, so
// tests can assert the coordinator's fail-close behavior.
type userStoreUnavailableRepoStub struct{}

// storeErr builds the error every method of this stub reports.
func (userStoreUnavailableRepoStub) storeErr() error {
	return errors.New("store unavailable")
}
func (s userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
	return false, s.storeErr()
}
func (s userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
	return nil, s.storeErr()
}
func (s userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, s.storeErr()
}
func (s userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, s.storeErr()
}
func (s userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return s.storeErr()
}
func (s userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return s.storeErr()
}
func (s userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
	return 0, s.storeErr()
}
// userMemoryIdempotencyRepoStub is an in-memory, mutex-guarded fake of the
// idempotency repository for user-handler tests. It mirrors the SQL
// repository's contract: first-writer-wins inserts, clone-on-read, and
// CAS-style status transitions.
// NOTE(review): this duplicates memoryIdempotencyRepoStub in the admin test
// file — consider extracting a shared test helper.
type userMemoryIdempotencyRepoStub struct {
	mu sync.Mutex // guards nextID and data
	nextID int64
	data map[string]*service.IdempotencyRecord
}
// newUserMemoryIdempotencyRepoStub returns an empty stub whose IDs start at 1.
func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub {
	return &userMemoryIdempotencyRepoStub{
		nextID: 1,
		data: make(map[string]*service.IdempotencyRecord),
	}
}
// key builds the composite map key for a (scope, idempotency-key-hash) pair.
func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string {
	return scope + "|" + keyHash
}
// clone deep-copies a record, including its pointer fields, so callers can
// never mutate the stub's stored state through a returned record.
func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
	if in == nil {
		return nil
	}
	out := *in
	if in.LockedUntil != nil {
		v := *in.LockedUntil
		out.LockedUntil = &v
	}
	if in.ResponseBody != nil {
		v := *in.ResponseBody
		out.ResponseBody = &v
	}
	if in.ResponseStatus != nil {
		v := *in.ResponseStatus
		out.ResponseStatus = &v
	}
	if in.ErrorReason != nil {
		v := *in.ErrorReason
		out.ErrorReason = &v
	}
	return &out
}
// CreateProcessing inserts a new record and reports true on success; an
// existing (scope, key-hash) pair yields (false, nil) — the in-memory analog
// of INSERT ... ON CONFLICT DO NOTHING. The assigned ID is written back into
// the caller's record.
func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	k := r.key(record.Scope, record.IdempotencyKeyHash)
	if _, ok := r.data[k]; ok {
		return false, nil
	}
	// Store a clone so later caller-side mutations cannot leak into the stub.
	cp := r.clone(record)
	cp.ID = r.nextID
	r.nextID++
	r.data[k] = cp
	record.ID = cp.ID
	return true, nil
}
// GetByScopeAndKeyHash returns a copy of the stored record, or (nil, nil)
// when no record exists for the pair.
func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.clone(r.data[r.key(scope, keyHash)]), nil
}
// TryReclaim atomically takes over the record identified by id when it is
// still in fromStatus and its lock has expired (or was never held), moving it
// back to "processing" with fresh lock/expiry times.
func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	// Linear scan by ID: the map is keyed by scope|hash, not ID; fine for tests.
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		if rec.Status != fromStatus {
			return false, nil
		}
		if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
			return false, nil
		}
		rec.Status = service.IdempotencyStatusProcessing
		rec.LockedUntil = &newLockedUntil
		rec.ExpiresAt = newExpiresAt
		rec.ErrorReason = nil
		return true, nil
	}
	return false, nil
}
// ExtendProcessingLock extends the lock/expiry of a record only while it is
// still "processing" AND still owned by the same request fingerprint.
func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
			return false, nil
		}
		rec.LockedUntil = &newLockedUntil
		rec.ExpiresAt = newExpiresAt
		return true, nil
	}
	return false, nil
}
// MarkSucceeded stores the response for replay, clears the lock and any prior
// error, and sets the replay-window expiry. Missing IDs are silently ignored.
func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		rec.Status = service.IdempotencyStatusSucceeded
		rec.LockedUntil = nil
		rec.ExpiresAt = expiresAt
		rec.ResponseStatus = &responseStatus
		rec.ResponseBody = &responseBody
		rec.ErrorReason = nil
		return nil
	}
	return nil
}
// MarkFailedRetryable records a retryable failure: the record keeps its lock
// until lockedUntil (backoff window) and stores the failure reason.
func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		rec.Status = service.IdempotencyStatusFailedRetryable
		rec.LockedUntil = &lockedUntil
		rec.ExpiresAt = expiresAt
		rec.ErrorReason = &errorReason
		return nil
	}
	return nil
}
// DeleteExpired is a no-op in the stub; cleanup is not exercised by these tests.
func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
	return 0, nil
}
// withUserSubject returns middleware that injects an authenticated user
// subject with the given ID into the request context, so handlers under test
// see a logged-in user.
func withUserSubject(userID int64) gin.HandlerFunc {
	subject := middleware.AuthSubject{UserID: userID}
	return func(c *gin.Context) {
		c.Set(string(middleware.ContextKeyUser), subject)
		c.Next()
	}
}
// TestExecuteUserIdempotentJSONFallbackWithoutCoordinator checks that the
// helper executes the operation directly when no coordinator is installed.
func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) {
	gin.SetMode(gin.TestMode)
	service.SetDefaultIdempotencyCoordinator(nil)
	executed := 0
	engine := gin.New()
	engine.Use(withUserSubject(1))
	engine.POST("/idempotent", func(c *gin.Context) {
		executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
			executed++
			return gin.H{"ok": true}, nil
		})
	})
	request := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
	request.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)
	require.Equal(t, http.StatusOK, recorder.Code)
	require.Equal(t, 1, executed)
}
// TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable checks the
// fail-close path: when the idempotency store is down and a key is supplied,
// the request is rejected with 503 and the side effect never runs.
func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
	gin.SetMode(gin.TestMode)
	service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
	t.Cleanup(func() {
		service.SetDefaultIdempotencyCoordinator(nil)
	})
	executed := 0
	engine := gin.New()
	engine.Use(withUserSubject(2))
	engine.POST("/idempotent", func(c *gin.Context) {
		executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
			executed++
			return gin.H{"ok": true}, nil
		})
	})
	request := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
	request.Header.Set("Content-Type", "application/json")
	request.Header.Set("Idempotency-Key", "k1")
	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)
	require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
	require.Equal(t, 0, executed)
}
// TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay
// verifies that two concurrent requests sharing one Idempotency-Key run the
// side effect exactly once and that a subsequent retry replays the stored
// response.
func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) {
	gin.SetMode(gin.TestMode)
	stubRepo := newUserMemoryIdempotencyRepoStub()
	coordCfg := service.DefaultIdempotencyConfig()
	coordCfg.ProcessingTimeout = 2 * time.Second
	service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(stubRepo, coordCfg))
	t.Cleanup(func() {
		service.SetDefaultIdempotencyCoordinator(nil)
	})
	var sideEffects atomic.Int32
	engine := gin.New()
	engine.Use(withUserSubject(3))
	engine.POST("/idempotent", func(c *gin.Context) {
		executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
			sideEffects.Add(1)
			// Hold the lock long enough for the second request to collide.
			time.Sleep(80 * time.Millisecond)
			return gin.H{"ok": true}, nil
		})
	})
	doRequest := func() (int, http.Header) {
		req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Idempotency-Key", "same-user-key")
		rec := httptest.NewRecorder()
		engine.ServeHTTP(rec, req)
		return rec.Code, rec.Header()
	}
	// Fire two requests at once; each writes its status into its own slot.
	statuses := make([]int, 2)
	var wg sync.WaitGroup
	for i := range statuses {
		wg.Add(1)
		go func(slot int) {
			defer wg.Done()
			statuses[slot], _ = doRequest()
		}(i)
	}
	wg.Wait()
	allowed := []int{http.StatusOK, http.StatusConflict}
	require.Contains(t, allowed, statuses[0])
	require.Contains(t, allowed, statuses[1])
	require.Equal(t, int32(1), sideEffects.Load())
	// A sequential retry must replay the stored success without re-executing.
	replayStatus, replayHeaders := doRequest()
	require.Equal(t, http.StatusOK, replayStatus)
	require.Equal(t, "true", replayHeaders.Get("X-Idempotency-Replayed"))
	require.Equal(t, int32(1), sideEffects.Load())
}

View File

@@ -53,8 +53,8 @@ func ProvideAdminHandlers(
} }
// ProvideSystemHandler creates admin.SystemHandler with UpdateService // ProvideSystemHandler creates admin.SystemHandler with UpdateService
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler { func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler {
return admin.NewSystemHandler(updateService) return admin.NewSystemHandler(updateService, lockService)
} }
// ProvideSettingHandler creates SettingHandler with version from BuildInfo // ProvideSettingHandler creates SettingHandler with version from BuildInfo
@@ -77,6 +77,8 @@ func ProvideHandlers(
soraGatewayHandler *SoraGatewayHandler, soraGatewayHandler *SoraGatewayHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
totpHandler *TotpHandler, totpHandler *TotpHandler,
_ *service.IdempotencyCoordinator,
_ *service.IdempotencyCleanupService,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
Auth: authHandler, Auth: authHandler,

View File

@@ -0,0 +1,237 @@
package repository
import (
"context"
"database/sql"
"errors"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// idempotencyRepository is the SQL-backed implementation of
// service.IdempotencyRepository, persisting records in the
// idempotency_records table via raw positional-parameter queries.
type idempotencyRepository struct {
	sql sqlExecutor // raw SQL executor (typically *sql.DB)
}
// NewIdempotencyRepository constructs the repository. The ent client is
// accepted but unused here — presumably to keep constructor signatures
// consistent with sibling repositories (TODO confirm); all queries run
// through the raw *sql.DB.
func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository {
	return &idempotencyRepository{sql: sqlDB}
}
// CreateProcessing atomically inserts a new idempotency record in its initial
// state. It returns (true, nil) when this caller won the insert, and
// (false, nil) when a record for the same (scope, idempotency_key_hash) pair
// already exists — ON CONFLICT DO NOTHING suppresses the RETURNING row, which
// surfaces here as sql.ErrNoRows. On success the generated ID and timestamps
// are written back into record.
func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) {
	if record == nil {
		// Nothing to insert; treated as "did not win" rather than an error.
		return false, nil
	}
	query := `
		INSERT INTO idempotency_records (
			scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at
		) VALUES ($1, $2, $3, $4, $5, $6)
		ON CONFLICT (scope, idempotency_key_hash) DO NOTHING
		RETURNING id, created_at, updated_at
	`
	var createdAt time.Time
	var updatedAt time.Time
	err := scanSingleRow(ctx, r.sql, query, []any{
		record.Scope,
		record.IdempotencyKeyHash,
		record.RequestFingerprint,
		record.Status,
		record.LockedUntil,
		record.ExpiresAt,
	}, &record.ID, &createdAt, &updatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		// Conflict path: another request already holds this key.
		return false, nil
	}
	if err != nil {
		return false, err
	}
	record.CreatedAt = createdAt
	record.UpdatedAt = updatedAt
	return true, nil
}
// GetByScopeAndKeyHash loads the record for a (scope, key-hash) pair.
// It returns (nil, nil) when no record exists, so callers can distinguish
// "not found" from a real query failure. Nullable columns (response_status,
// response_body, error_reason, locked_until) are mapped to pointer fields,
// left nil when NULL.
func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
	query := `
		SELECT
			id, scope, idempotency_key_hash, request_fingerprint, status, response_status,
			response_body, error_reason, locked_until, expires_at, created_at, updated_at
		FROM idempotency_records
		WHERE scope = $1 AND idempotency_key_hash = $2
	`
	record := &service.IdempotencyRecord{}
	// Scratch holders for nullable columns; copied into pointer fields below.
	var responseStatus sql.NullInt64
	var responseBody sql.NullString
	var errorReason sql.NullString
	var lockedUntil sql.NullTime
	err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash},
		&record.ID,
		&record.Scope,
		&record.IdempotencyKeyHash,
		&record.RequestFingerprint,
		&record.Status,
		&responseStatus,
		&responseBody,
		&errorReason,
		&lockedUntil,
		&record.ExpiresAt,
		&record.CreatedAt,
		&record.UpdatedAt,
	)
	if errors.Is(err, sql.ErrNoRows) {
		// No record for this pair: not an error for callers.
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	if responseStatus.Valid {
		v := int(responseStatus.Int64)
		record.ResponseStatus = &v
	}
	if responseBody.Valid {
		v := responseBody.String
		record.ResponseBody = &v
	}
	if errorReason.Valid {
		v := errorReason.String
		record.ErrorReason = &v
	}
	if lockedUntil.Valid {
		v := lockedUntil.Time
		record.LockedUntil = &v
	}
	return record, nil
}
// TryReclaim atomically moves record id from fromStatus back into the
// "processing" state, refreshing its lock and expiry windows. The
// transition only succeeds while the row still carries fromStatus AND
// its lock window has elapsed (or no lock is held), so concurrent
// reclaimers race safely: exactly one caller sees true.
func (r *idempotencyRepository) TryReclaim(
	ctx context.Context,
	id int64,
	fromStatus string,
	now, newLockedUntil, newExpiresAt time.Time,
) (bool, error) {
	const query = `
		UPDATE idempotency_records
		SET status = $2,
			locked_until = $3,
			error_reason = NULL,
			updated_at = NOW(),
			expires_at = $4
		WHERE id = $1
			AND status = $5
			AND (locked_until IS NULL OR locked_until <= $6)
	`
	result, execErr := r.sql.ExecContext(ctx, query,
		id,
		service.IdempotencyStatusProcessing,
		newLockedUntil,
		newExpiresAt,
		fromStatus,
		now,
	)
	if execErr != nil {
		return false, execErr
	}
	rows, raErr := result.RowsAffected()
	if raErr != nil {
		return false, raErr
	}
	return rows > 0, nil
}
// ExtendProcessingLock pushes out the lock and expiry deadlines of a
// record that is still in "processing" state and still carries the
// caller's requestFingerprint. Returns true only when the row matched
// and was updated, so a caller that lost ownership learns about it.
func (r *idempotencyRepository) ExtendProcessingLock(
	ctx context.Context,
	id int64,
	requestFingerprint string,
	newLockedUntil,
	newExpiresAt time.Time,
) (bool, error) {
	const query = `
		UPDATE idempotency_records
		SET locked_until = $2,
			expires_at = $3,
			updated_at = NOW()
		WHERE id = $1
			AND status = $4
			AND request_fingerprint = $5
	`
	result, execErr := r.sql.ExecContext(
		ctx,
		query,
		id,
		newLockedUntil,
		newExpiresAt,
		service.IdempotencyStatusProcessing,
		requestFingerprint,
	)
	if execErr != nil {
		return false, execErr
	}
	rows, raErr := result.RowsAffected()
	if raErr != nil {
		return false, raErr
	}
	return rows > 0, nil
}
// MarkSucceeded finalizes a record as "succeeded", storing the response
// status/body for later replay, clearing the error and lock fields, and
// resetting the retention deadline to expiresAt.
func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
	const query = `
		UPDATE idempotency_records
		SET status = $2,
			response_status = $3,
			response_body = $4,
			error_reason = NULL,
			locked_until = NULL,
			expires_at = $5,
			updated_at = NOW()
		WHERE id = $1
	`
	args := []any{
		id,
		service.IdempotencyStatusSucceeded,
		responseStatus,
		responseBody,
		expiresAt,
	}
	_, err := r.sql.ExecContext(ctx, query, args...)
	return err
}
// MarkFailedRetryable records a retryable failure: the record keeps its
// key, gains an error reason, and is locked until lockedUntil so that
// retries back off before the key can be reclaimed.
func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
	const query = `
		UPDATE idempotency_records
		SET status = $2,
			error_reason = $3,
			locked_until = $4,
			expires_at = $5,
			updated_at = NOW()
		WHERE id = $1
	`
	args := []any{
		id,
		service.IdempotencyStatusFailedRetryable,
		errorReason,
		lockedUntil,
		expiresAt,
	}
	_, err := r.sql.ExecContext(ctx, query, args...)
	return err
}
// DeleteExpired removes up to limit records whose retention deadline is
// at or before now, oldest first, and returns how many were deleted.
// A non-positive limit falls back to a batch of 500. The CTE bounds the
// DELETE so each cleanup round does a fixed amount of work.
func (r *idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) {
	const defaultBatch = 500
	if limit <= 0 {
		limit = defaultBatch
	}
	const query = `
		WITH victims AS (
			SELECT id
			FROM idempotency_records
			WHERE expires_at <= $1
			ORDER BY expires_at ASC
			LIMIT $2
		)
		DELETE FROM idempotency_records
		WHERE id IN (SELECT id FROM victims)
	`
	result, err := r.sql.ExecContext(ctx, query, now, limit)
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}

View File

@@ -0,0 +1,144 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// TestIdempotencyRepo_CreateProcessing_CompeteSameKey verifies that the
// first insert for a (scope, key hash) pair wins ownership and that a
// second insert with the same pair — even with a different request
// fingerprint — is silently de-duplicated.
func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) {
	tx := testTx(t)
	repo := &idempotencyRepository{sql: tx}
	ctx := context.Background()
	now := time.Now().UTC()
	record := &service.IdempotencyRecord{
		Scope:              uniqueTestValue(t, "idem-scope-create"),
		IdempotencyKeyHash: uniqueTestValue(t, "idem-hash"),
		RequestFingerprint: uniqueTestValue(t, "idem-fp"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(30 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	// First writer claims ownership and receives a generated ID.
	owner, err := repo.CreateProcessing(ctx, record)
	require.NoError(t, err)
	require.True(t, owner)
	require.NotZero(t, record.ID)
	// Same scope + key hash but a different fingerprint must still
	// lose the race: uniqueness is keyed only on (scope, key hash).
	duplicate := &service.IdempotencyRecord{
		Scope:              record.Scope,
		IdempotencyKeyHash: record.IdempotencyKeyHash,
		RequestFingerprint: uniqueTestValue(t, "idem-fp-other"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(30 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err = repo.CreateProcessing(ctx, duplicate)
	require.NoError(t, err)
	require.False(t, owner, "same scope+key hash should be de-duplicated")
}
// TestIdempotencyRepo_TryReclaim_StatusAndLockWindow verifies both
// sides of the reclaim guard: a failed_retryable record whose lock has
// already expired can be reclaimed back into "processing", while one
// whose lock window is still open cannot.
func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) {
	tx := testTx(t)
	repo := &idempotencyRepository{sql: tx}
	ctx := context.Background()
	now := time.Now().UTC()
	record := &service.IdempotencyRecord{
		Scope:              uniqueTestValue(t, "idem-scope-reclaim"),
		IdempotencyKeyHash: uniqueTestValue(t, "idem-hash-reclaim"),
		RequestFingerprint: uniqueTestValue(t, "idem-fp-reclaim"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(10 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err := repo.CreateProcessing(ctx, record)
	require.NoError(t, err)
	require.True(t, owner)
	// Case 1: mark failed with a lock deadline already in the past.
	require.NoError(t, repo.MarkFailedRetryable(
		ctx,
		record.ID,
		"RETRYABLE_FAILURE",
		now.Add(-2*time.Second),
		now.Add(24*time.Hour),
	))
	newLockedUntil := now.Add(20 * time.Second)
	reclaimed, err := repo.TryReclaim(
		ctx,
		record.ID,
		service.IdempotencyStatusFailedRetryable,
		now,
		newLockedUntil,
		now.Add(24*time.Hour),
	)
	require.NoError(t, err)
	require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim")
	// The reclaimed row must be back in processing with a fresh lock.
	got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, service.IdempotencyStatusProcessing, got.Status)
	require.NotNil(t, got.LockedUntil)
	require.True(t, got.LockedUntil.After(now))
	// Case 2: mark failed again, but with the lock still in the future.
	require.NoError(t, repo.MarkFailedRetryable(
		ctx,
		record.ID,
		"RETRYABLE_FAILURE",
		now.Add(20*time.Second),
		now.Add(24*time.Hour),
	))
	reclaimed, err = repo.TryReclaim(
		ctx,
		record.ID,
		service.IdempotencyStatusFailedRetryable,
		now,
		now.Add(40*time.Second),
		now.Add(24*time.Hour),
	)
	require.NoError(t, err)
	require.False(t, reclaimed, "within lock window should not reclaim")
}
// TestIdempotencyRepo_StatusTransition_ToSucceeded verifies that
// MarkSucceeded stores the replayable response (status + body),
// transitions the record to "succeeded", and clears the lock.
func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
	tx := testTx(t)
	repo := &idempotencyRepository{sql: tx}
	ctx := context.Background()
	now := time.Now().UTC()
	record := &service.IdempotencyRecord{
		Scope:              uniqueTestValue(t, "idem-scope-success"),
		IdempotencyKeyHash: uniqueTestValue(t, "idem-hash-success"),
		RequestFingerprint: uniqueTestValue(t, "idem-fp-success"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(10 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err := repo.CreateProcessing(ctx, record)
	require.NoError(t, err)
	require.True(t, owner)
	require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour)))
	got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, service.IdempotencyStatusSucceeded, got.Status)
	require.NotNil(t, got.ResponseStatus)
	require.Equal(t, 200, *got.ResponseStatus)
	require.NotNil(t, got.ResponseBody)
	require.Equal(t, `{"ok":true}`, *got.ResponseBody)
	// Success must release the processing lock.
	require.Nil(t, got.LockedUntil)
}
func ptrTime(v time.Time) *time.Time {
return &v
}

View File

@@ -60,6 +60,7 @@ var ProviderSet = wire.NewSet(
NewAnnouncementRepository, NewAnnouncementRepository,
NewAnnouncementReadRepository, NewAnnouncementReadRepository,
NewUsageLogRepository, NewUsageLogRepository,
NewIdempotencyRepository,
NewUsageCleanupRepository, NewUsageCleanupRepository,
NewDashboardAggregationRepository, NewDashboardAggregationRepository,
NewSettingRepository, NewSettingRepository,

View File

@@ -35,6 +35,8 @@ var (
const ( const (
apiKeyMaxErrorsPerHour = 20 apiKeyMaxErrorsPerHour = 20
apiKeyLastUsedMinTouch = 30 * time.Second apiKeyLastUsedMinTouch = 30 * time.Second
// DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。
apiKeyLastUsedFailBackoff = 5 * time.Second
) )
type APIKeyRepository interface { type APIKeyRepository interface {
@@ -129,7 +131,7 @@ type APIKeyService struct {
authCacheL1 *ristretto.Cache authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group authGroup singleflight.Group
lastUsedTouchL1 sync.Map // keyID -> time.Time lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
lastUsedTouchSF singleflight.Group lastUsedTouchSF singleflight.Group
} }
@@ -574,7 +576,7 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
now := time.Now() now := time.Now()
if v, ok := s.lastUsedTouchL1.Load(keyID); ok { if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
if last, ok := v.(time.Time); ok && now.Sub(last) < apiKeyLastUsedMinTouch { if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
return nil return nil
} }
} }
@@ -582,15 +584,16 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
_, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) { _, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) {
latest := time.Now() latest := time.Now()
if v, ok := s.lastUsedTouchL1.Load(keyID); ok { if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
if last, ok := v.(time.Time); ok && latest.Sub(last) < apiKeyLastUsedMinTouch { if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
return nil, nil return nil, nil
} }
} }
if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil { if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil {
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff))
return nil, fmt.Errorf("touch api key last used: %w", err) return nil, fmt.Errorf("touch api key last used: %w", err)
} }
s.lastUsedTouchL1.Store(keyID, latest) s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch))
return nil, nil return nil, nil
}) })
return err return err

View File

@@ -79,8 +79,27 @@ func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) {
require.ErrorContains(t, err, "touch api key last used") require.ErrorContains(t, err, "touch api key last used")
require.Equal(t, []int64{123}, repo.touchedIDs) require.Equal(t, []int64{123}, repo.touchedIDs)
_, ok := svc.lastUsedTouchL1.Load(int64(123)) cached, ok := svc.lastUsedTouchL1.Load(int64(123))
require.False(t, ok, "failed touch should not update debounce cache") require.True(t, ok, "failed touch should still update retry debounce cache")
_, isTime := cached.(time.Time)
require.True(t, isTime)
}
func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) {
repo := &apiKeyRepoStub{
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
return errors.New("db write failed")
},
}
svc := &APIKeyService{apiKeyRepo: repo}
firstErr := svc.TouchLastUsed(context.Background(), 456)
require.Error(t, firstErr)
require.ErrorContains(t, firstErr, "touch api key last used")
secondErr := svc.TouchLastUsed(context.Background(), 456)
require.NoError(t, secondErr, "failed touch should be debounced and skip immediate retry")
require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again")
} }
type touchSingleflightRepo struct { type touchSingleflightRepo struct {

View File

@@ -0,0 +1,471 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
)
const (
IdempotencyStatusProcessing = "processing"
IdempotencyStatusSucceeded = "succeeded"
IdempotencyStatusFailedRetryable = "failed_retryable"
)
var (
ErrIdempotencyKeyRequired = infraerrors.BadRequest("IDEMPOTENCY_KEY_REQUIRED", "idempotency key is required")
ErrIdempotencyKeyInvalid = infraerrors.BadRequest("IDEMPOTENCY_KEY_INVALID", "idempotency key is invalid")
ErrIdempotencyKeyConflict = infraerrors.Conflict("IDEMPOTENCY_KEY_CONFLICT", "idempotency key reused with different payload")
ErrIdempotencyInProgress = infraerrors.Conflict("IDEMPOTENCY_IN_PROGRESS", "idempotent request is still processing")
ErrIdempotencyRetryBackoff = infraerrors.Conflict("IDEMPOTENCY_RETRY_BACKOFF", "idempotent request is in retry backoff window")
ErrIdempotencyStoreUnavail = infraerrors.ServiceUnavailable("IDEMPOTENCY_STORE_UNAVAILABLE", "idempotency store unavailable")
ErrIdempotencyInvalidPayload = infraerrors.BadRequest("IDEMPOTENCY_PAYLOAD_INVALID", "failed to normalize request payload")
)
// IdempotencyRecord is the persisted state of one idempotent request,
// uniquely identified by (Scope, IdempotencyKeyHash).
type IdempotencyRecord struct {
	ID                 int64
	Scope              string  // logical operation scope, e.g. one API endpoint
	IdempotencyKeyHash string  // SHA-256 hex of the client-supplied key
	RequestFingerprint string  // hash of method/route/actor/payload to detect key reuse
	Status             string  // one of the IdempotencyStatus* constants
	ResponseStatus     *int    // stored HTTP-like status, set once succeeded
	ResponseBody       *string // stored (redacted) JSON body for replay
	ErrorReason        *string // reason code of the last retryable failure
	LockedUntil        *time.Time // while in the future, the record is owned/backing off
	ExpiresAt          time.Time  // retention deadline; eligible for cleanup afterwards
	CreatedAt          time.Time
	UpdatedAt          time.Time
}
// IdempotencyRepository is the persistence contract the coordinator
// relies on. Implementations must make CreateProcessing and TryReclaim
// atomic so that concurrent requests race safely on the same key.
type IdempotencyRepository interface {
	// CreateProcessing inserts a processing claim; true means ownership won.
	CreateProcessing(ctx context.Context, record *IdempotencyRecord) (bool, error)
	// GetByScopeAndKeyHash returns (nil, nil) when no record exists.
	GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*IdempotencyRecord, error)
	// TryReclaim moves a record from fromStatus back to processing once its lock lapsed.
	TryReclaim(ctx context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error)
	// ExtendProcessingLock prolongs ownership while the fingerprint still matches.
	ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error)
	// MarkSucceeded stores the replayable response and releases the lock.
	MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error
	// MarkFailedRetryable records a failure and starts a backoff window.
	MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error
	// DeleteExpired removes up to limit expired records and returns the count.
	DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error)
}
// IdempotencyConfig tunes the coordinator's timing and storage limits.
type IdempotencyConfig struct {
	DefaultTTL           time.Duration // retention for regular write-endpoint records
	SystemOperationTTL   time.Duration // retention for system-operation records
	ProcessingTimeout    time.Duration // how long an owner may hold the processing lock
	FailedRetryBackoff   time.Duration // lock window applied after a retryable failure
	MaxStoredResponseLen int           // cap (bytes) on the stored response body; 0 = unlimited
	ObserveOnly          bool          // when true, a missing key is tolerated even if required
}
// DefaultIdempotencyConfig returns the built-in tuning used when no
// explicit configuration is supplied.
func DefaultIdempotencyConfig() IdempotencyConfig {
	cfg := IdempotencyConfig{
		DefaultTTL:           24 * time.Hour,
		SystemOperationTTL:   time.Hour,
		ProcessingTimeout:    30 * time.Second,
		FailedRetryBackoff:   5 * time.Second,
		MaxStoredResponseLen: 64 * 1024,
	}
	// Observe first, enforce later, so legacy clients without an
	// Idempotency-Key are not broken immediately.
	cfg.ObserveOnly = true
	return cfg
}
// IdempotencyExecuteOptions describes one protected invocation.
type IdempotencyExecuteOptions struct {
	Scope          string // uniqueness scope for the key (typically per endpoint)
	ActorScope     string // identity of the caller, folded into the fingerprint
	Method         string // HTTP method; defaults to POST when empty
	Route          string // route template; defaults to "/" when empty
	IdempotencyKey string // raw client-supplied key; may be empty
	Payload        any    // request payload, JSON-hashed into the fingerprint
	TTL            time.Duration // record retention; <=0 falls back to cfg.DefaultTTL
	RequireKey     bool          // reject keyless requests (only when ObserveOnly is off)
}

// IdempotencyExecuteResult carries the operation's outcome.
type IdempotencyExecuteResult struct {
	Data     any  // fresh result, or the decoded stored response on replay
	Replayed bool // true when Data came from a previously stored success
}

// IdempotencyCoordinator implements the claim/replay/reclaim state
// machine on top of an IdempotencyRepository.
type IdempotencyCoordinator struct {
	repo IdempotencyRepository
	cfg  IdempotencyConfig
}
var (
	// defaultIdempotencyMu guards defaultIdempotencySvc.
	defaultIdempotencyMu sync.RWMutex
	// defaultIdempotencySvc is the process-wide coordinator instance.
	defaultIdempotencySvc *IdempotencyCoordinator
)

// SetDefaultIdempotencyCoordinator installs svc as the process-wide
// default coordinator.
func SetDefaultIdempotencyCoordinator(svc *IdempotencyCoordinator) {
	defaultIdempotencyMu.Lock()
	defaultIdempotencySvc = svc
	defaultIdempotencyMu.Unlock()
}

// DefaultIdempotencyCoordinator returns the process-wide default
// coordinator, or nil when none has been installed yet.
func DefaultIdempotencyCoordinator() *IdempotencyCoordinator {
	defaultIdempotencyMu.RLock()
	defer defaultIdempotencyMu.RUnlock()
	return defaultIdempotencySvc
}
// DefaultWriteIdempotencyTTL returns the record TTL for regular write
// endpoints: the installed coordinator's configured value when positive,
// otherwise the package default.
func DefaultWriteIdempotencyTTL() time.Duration {
	defaultTTL := DefaultIdempotencyConfig().DefaultTTL
	if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.DefaultTTL > 0 {
		return coordinator.cfg.DefaultTTL
	}
	return defaultTTL
}

// DefaultSystemOperationIdempotencyTTL returns the record TTL for
// system operations, with the same fallback behavior as
// DefaultWriteIdempotencyTTL.
func DefaultSystemOperationIdempotencyTTL() time.Duration {
	defaultTTL := DefaultIdempotencyConfig().SystemOperationTTL
	if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.SystemOperationTTL > 0 {
		return coordinator.cfg.SystemOperationTTL
	}
	return defaultTTL
}
// NewIdempotencyCoordinator wires a coordinator to its persistence
// layer and configuration.
func NewIdempotencyCoordinator(repo IdempotencyRepository, cfg IdempotencyConfig) *IdempotencyCoordinator {
	coordinator := &IdempotencyCoordinator{repo: repo, cfg: cfg}
	return coordinator
}
// NormalizeIdempotencyKey trims and validates a client-supplied key.
// An empty (or all-whitespace) key is returned as "" with no error —
// the caller decides whether a key is mandatory. Keys longer than 128
// bytes or containing anything outside printable non-space ASCII
// (33..126) are rejected with ErrIdempotencyKeyInvalid.
func NormalizeIdempotencyKey(raw string) (string, error) {
	key := strings.TrimSpace(raw)
	switch {
	case key == "":
		return "", nil
	case len(key) > 128:
		return "", ErrIdempotencyKeyInvalid
	}
	for _, r := range key {
		if r >= 33 && r <= 126 {
			continue
		}
		return "", ErrIdempotencyKeyInvalid
	}
	return key, nil
}
// HashIdempotencyKey returns the lowercase hex SHA-256 digest of key;
// only this digest — never the raw client key — is persisted.
func HashIdempotencyKey(key string) string {
	digest := sha256.Sum256([]byte(key))
	return hex.EncodeToString(digest[:])
}
// BuildIdempotencyFingerprint derives a stable request fingerprint from
// method, route, actor scope and the JSON encoding of payload. Empty
// fields fall back to "POST", "/" and "anonymous" so the same logical
// request always hashes identically. A payload that cannot be
// JSON-encoded yields ErrIdempotencyInvalidPayload.
func BuildIdempotencyFingerprint(method, route, actorScope string, payload any) (string, error) {
	if method == "" {
		method = "POST"
	}
	if route == "" {
		route = "/"
	}
	if actorScope == "" {
		actorScope = "anonymous"
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return "", ErrIdempotencyInvalidPayload.WithCause(err)
	}
	material := strings.ToUpper(method) + "\n" + route + "\n" + actorScope + "\n" + string(encoded)
	digest := sha256.Sum256([]byte(material))
	return hex.EncodeToString(digest[:]), nil
}
// RetryAfterSecondsFromError extracts the "retry_after" hint (seconds)
// from an ApplicationError's metadata. It returns 0 when err is not an
// ApplicationError, carries no metadata, or the value is missing,
// malformed, or non-positive.
func RetryAfterSecondsFromError(err error) int {
	// Idiomatic errors.As target: a nil typed pointer that As populates
	// on match (the previous new(...) allocated a throwaway value).
	// The nil check is kept: a typed-nil *ApplicationError in the chain
	// would still satisfy errors.As.
	var appErr *infraerrors.ApplicationError
	if !errors.As(err, &appErr) || appErr == nil || appErr.Metadata == nil {
		return 0
	}
	v := strings.TrimSpace(appErr.Metadata["retry_after"])
	if v == "" {
		return 0
	}
	seconds, convErr := strconv.Atoi(v)
	if convErr != nil || seconds <= 0 {
		return 0
	}
	return seconds
}
// Execute runs an operation under idempotency protection keyed by
// (opts.Scope, hash(opts.IdempotencyKey)).
//
// High-level flow:
//  1. Normalize the key. Without a key the operation runs unprotected
//     (unless RequireKey is set and ObserveOnly is disabled).
//  2. Insert a "processing" claim row; the winner becomes the owner.
//  3. A loser inspects the existing row: replay a stored success,
//     reject a fingerprint mismatch, surface in-progress / backoff
//     conflicts (with a retry_after hint), or reclaim a row whose TTL
//     or retry backoff has elapsed.
//  4. The owner runs execute and persists the outcome: a redacted JSON
//     success body, or failed_retryable plus a short backoff window.
//
// Replayed results are flagged via IdempotencyExecuteResult.Replayed.
func (c *IdempotencyCoordinator) Execute(
	ctx context.Context,
	opts IdempotencyExecuteOptions,
	execute func(context.Context) (any, error),
) (*IdempotencyExecuteResult, error) {
	if execute == nil {
		return nil, infraerrors.InternalServer("IDEMPOTENCY_EXECUTOR_NIL", "idempotency executor is nil")
	}
	key, err := NormalizeIdempotencyKey(opts.IdempotencyKey)
	if err != nil {
		return nil, err
	}
	if key == "" {
		// Keyless request: run directly, unless the key is both
		// required and enforcement (non-observe) mode is active.
		if opts.RequireKey && !c.cfg.ObserveOnly {
			return nil, ErrIdempotencyKeyRequired
		}
		data, execErr := execute(ctx)
		if execErr != nil {
			return nil, execErr
		}
		return &IdempotencyExecuteResult{Data: data}, nil
	}
	if c.repo == nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "repo_nil")
		return nil, ErrIdempotencyStoreUnavail
	}
	if opts.Scope == "" {
		return nil, infraerrors.BadRequest("IDEMPOTENCY_SCOPE_REQUIRED", "idempotency scope is required")
	}
	fingerprint, err := BuildIdempotencyFingerprint(opts.Method, opts.Route, opts.ActorScope, opts.Payload)
	if err != nil {
		return nil, err
	}
	ttl := opts.TTL
	if ttl <= 0 {
		ttl = c.cfg.DefaultTTL
	}
	now := time.Now()
	expiresAt := now.Add(ttl)
	lockedUntil := now.Add(c.cfg.ProcessingTimeout)
	keyHash := HashIdempotencyKey(key)
	record := &IdempotencyRecord{
		Scope:              opts.Scope,
		IdempotencyKeyHash: keyHash,
		RequestFingerprint: fingerprint,
		Status:             IdempotencyStatusProcessing,
		LockedUntil:        &lockedUntil,
		ExpiresAt:          expiresAt,
	}
	// Step 2: race for ownership via an atomic claim insert.
	owner, err := c.repo.CreateProcessing(ctx, record)
	if err != nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "create_processing_error")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
			"operation": "create_processing",
		})
		return nil, ErrIdempotencyStoreUnavail.WithCause(err)
	}
	if owner {
		recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "new_claim"})
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "none->processing", false, map[string]string{
			"claim_mode": "new",
		})
	}
	// Step 3: lost the insert race — decide based on the existing record.
	if !owner {
		existing, getErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash)
		if getErr != nil {
			RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_error")
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
				"operation": "get_existing",
			})
			return nil, ErrIdempotencyStoreUnavail.WithCause(getErr)
		}
		if existing == nil {
			// Claim lost but record vanished (e.g. concurrent cleanup):
			// treat as a store inconsistency rather than guessing.
			RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing")
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
				"operation": "missing_existing",
			})
			return nil, ErrIdempotencyStoreUnavail
		}
		if existing.RequestFingerprint != fingerprint {
			// Same key, different request: key reuse is a client error.
			recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"})
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil)
			return nil, ErrIdempotencyKeyConflict
		}
		reclaimedByExpired := false
		if !existing.ExpiresAt.After(now) {
			// Record past its retention deadline: try to take it over
			// regardless of its current status.
			taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, existing.Status, now, lockedUntil, expiresAt)
			if reclaimErr != nil {
				RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_expired_error")
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->store_unavailable", false, map[string]string{
					"operation": "try_reclaim_expired",
				})
				return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
			}
			if taken {
				reclaimedByExpired = true
				recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "expired_reclaim"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->processing", false, map[string]string{
					"claim_mode": "expired_reclaim",
				})
				record.ID = existing.ID
			} else {
				// Someone else won the expired-reclaim race — refresh the
				// record and fall through to the status switch below.
				latest, latestErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash)
				if latestErr != nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_after_expired_reclaim_error")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
						"operation": "get_existing_after_expired_reclaim",
					})
					return nil, ErrIdempotencyStoreUnavail.WithCause(latestErr)
				}
				if latest == nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing_after_expired_reclaim")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
						"operation": "missing_existing_after_expired_reclaim",
					})
					return nil, ErrIdempotencyStoreUnavail
				}
				if latest.RequestFingerprint != fingerprint {
					recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"})
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil)
					return nil, ErrIdempotencyKeyConflict
				}
				existing = latest
			}
		}
		if !reclaimedByExpired {
			switch existing.Status {
			case IdempotencyStatusSucceeded:
				// Replay the stored response without re-executing.
				data, parseErr := c.decodeStoredResponse(existing.ResponseBody)
				if parseErr != nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "decode_stored_response_error")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->store_unavailable", false, map[string]string{
						"operation": "decode_stored_response",
					})
					return nil, ErrIdempotencyStoreUnavail.WithCause(parseErr)
				}
				recordIdempotencyReplay(opts.Route, opts.Scope, nil)
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->replayed", true, nil)
				return &IdempotencyExecuteResult{Data: data, Replayed: true}, nil
			case IdempotencyStatusProcessing:
				// Another request is actively executing: report conflict
				// with a retry_after derived from its lock deadline.
				recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "in_progress"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->conflict", false, nil)
				return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now)
			case IdempotencyStatusFailedRetryable:
				if existing.LockedUntil != nil && existing.LockedUntil.After(now) {
					// Still inside the failure backoff window.
					recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "retry_backoff"})
					recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil)
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->retry_backoff_conflict", false, nil)
					return nil, c.conflictWithRetryAfter(ErrIdempotencyRetryBackoff, existing.LockedUntil, now)
				}
				// Backoff elapsed: race to reclaim the failed record.
				taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, IdempotencyStatusFailedRetryable, now, lockedUntil, expiresAt)
				if reclaimErr != nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_error")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->store_unavailable", false, map[string]string{
						"operation": "try_reclaim",
					})
					return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
				}
				if !taken {
					recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "reclaim_race"})
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->conflict", false, map[string]string{
						"conflict": "reclaim_race",
					})
					return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now)
				}
				recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "reclaim"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->processing", false, map[string]string{
					"claim_mode": "reclaim",
				})
				record.ID = existing.ID
			default:
				// Unknown status value — refuse rather than risk a double run.
				recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "unexpected_status"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->conflict", false, map[string]string{
					"status": existing.Status,
				})
				return nil, ErrIdempotencyKeyConflict
			}
		}
	}
	// Invariant: by now we must own a persisted record with an ID.
	if record.ID == 0 {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "record_id_missing")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
			"operation": "record_id_missing",
		})
		return nil, ErrIdempotencyStoreUnavail
	}
	// Step 4: owner executes; duration is recorded regardless of outcome.
	execStart := time.Now()
	defer func() {
		recordIdempotencyProcessingDuration(opts.Route, opts.Scope, time.Since(execStart), nil)
	}()
	data, execErr := execute(ctx)
	if execErr != nil {
		// Persist the failure with a backoff window; the original
		// execution error is returned even if the bookkeeping write fails.
		backoffUntil := time.Now().Add(c.cfg.FailedRetryBackoff)
		reason := infraerrors.Reason(execErr)
		if reason == "" {
			reason = "EXECUTION_FAILED"
		}
		recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil)
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->failed_retryable", false, map[string]string{
			"reason": reason,
		})
		if markErr := c.repo.MarkFailedRetryable(ctx, record.ID, reason, backoffUntil, expiresAt); markErr != nil {
			RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_failed_retryable_error")
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
				"operation": "mark_failed_retryable",
			})
		}
		return nil, execErr
	}
	storedBody, marshalErr := c.marshalStoredResponse(data)
	if marshalErr != nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "marshal_response_error")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
			"operation": "marshal_response",
		})
		return nil, ErrIdempotencyStoreUnavail.WithCause(marshalErr)
	}
	if markErr := c.repo.MarkSucceeded(ctx, record.ID, 200, storedBody, expiresAt); markErr != nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_succeeded_error")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
			"operation": "mark_succeeded",
		})
		return nil, ErrIdempotencyStoreUnavail.WithCause(markErr)
	}
	logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->succeeded", false, nil)
	return &IdempotencyExecuteResult{Data: data}, nil
}
// conflictWithRetryAfter decorates base with a retry_after metadata
// hint derived from the remaining lock window, clamped to at least one
// second. When no lock deadline is known, base is returned untouched.
func (c *IdempotencyCoordinator) conflictWithRetryAfter(base *infraerrors.ApplicationError, lockedUntil *time.Time, now time.Time) error {
	if lockedUntil == nil {
		return base
	}
	remaining := int(lockedUntil.Sub(now).Seconds())
	if remaining < 1 {
		remaining = 1
	}
	return base.WithMetadata(map[string]string{"retry_after": strconv.Itoa(remaining)})
}
// marshalStoredResponse JSON-encodes data, redacts sensitive content,
// and caps the stored body at cfg.MaxStoredResponseLen bytes.
// NOTE(review): truncation appends a plain-text marker, so an
// over-limit body is no longer valid JSON and decodeStoredResponse
// will fail on replay — confirm whether oversized responses should be
// stored differently.
func (c *IdempotencyCoordinator) marshalStoredResponse(data any) (string, error) {
	encoded, err := json.Marshal(data)
	if err != nil {
		return "", err
	}
	body := logredact.RedactText(string(encoded))
	if limit := c.cfg.MaxStoredResponseLen; limit > 0 && len(body) > limit {
		body = body[:limit] + "...(truncated)"
	}
	return body, nil
}
// decodeStoredResponse turns a persisted response body back into a
// generic value for replay. A nil or blank body decodes to an empty map.
func (c *IdempotencyCoordinator) decodeStoredResponse(stored *string) (any, error) {
	if stored == nil || strings.TrimSpace(*stored) == "" {
		return map[string]any{}, nil
	}
	var decoded any
	if err := json.Unmarshal([]byte(*stored), &decoded); err != nil {
		return nil, fmt.Errorf("decode stored response: %w", err)
	}
	return decoded, nil
}

View File

@@ -0,0 +1,91 @@
package service
import (
"context"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// IdempotencyCleanupService periodically deletes expired idempotency
// records so the backing table does not grow without bound.
type IdempotencyCleanupService struct {
	repo      IdempotencyRepository
	interval  time.Duration // pause between cleanup rounds
	batch     int           // max rows deleted per round
	startOnce sync.Once     // makes Start idempotent
	stopOnce  sync.Once     // makes Stop idempotent
	stopCh    chan struct{} // closed by Stop to terminate runLoop
}
// NewIdempotencyCleanupService builds the cleanup service. Positive
// config overrides replace the defaults of a 60-second interval and a
// 500-row batch; a nil config keeps the defaults.
func NewIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService {
	svc := &IdempotencyCleanupService{
		repo:     repo,
		interval: 60 * time.Second,
		batch:    500,
		stopCh:   make(chan struct{}),
	}
	if cfg == nil {
		return svc
	}
	if secs := cfg.Idempotency.CleanupIntervalSeconds; secs > 0 {
		svc.interval = time.Duration(secs) * time.Second
	}
	if size := cfg.Idempotency.CleanupBatchSize; size > 0 {
		svc.batch = size
	}
	return svc
}
// Start launches the background cleanup loop at most once. It is a
// no-op on a nil receiver or when no repository is configured.
func (s *IdempotencyCleanupService) Start() {
	if s == nil {
		return
	}
	if s.repo == nil {
		return
	}
	s.startOnce.Do(func() {
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] started interval=%s batch=%d", s.interval, s.batch)
		go s.runLoop()
	})
}
// Stop signals the cleanup loop to exit. Safe on a nil receiver and
// safe to call multiple times.
func (s *IdempotencyCleanupService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		// Closing stopCh wakes runLoop's select and ends the goroutine.
		close(s.stopCh)
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] stopped")
	})
}
// runLoop drives periodic cleanup until stopCh is closed.
func (s *IdempotencyCleanupService) runLoop() {
	ticker := time.NewTicker(s.interval)
	defer ticker.Stop()
	// Run one round immediately so any backlog left from before a
	// restart is drained without waiting for the first tick.
	s.cleanupOnce()
	for {
		select {
		case <-ticker.C:
			s.cleanupOnce()
		case <-s.stopCh:
			return
		}
	}
}
// cleanupOnce performs a single bounded deletion pass with a 10-second
// budget, logging failures and non-zero deletion counts.
func (s *IdempotencyCleanupService) cleanupOnce() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	deleted, err := s.repo.DeleteExpired(ctx, time.Now(), s.batch)
	switch {
	case err != nil:
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleanup failed err=%v", err)
	case deleted > 0:
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleaned expired records count=%d", deleted)
	}
}

View File

@@ -0,0 +1,69 @@
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// idempotencyCleanupRepoStub is a no-op IdempotencyRepository whose
// only observable behavior is recording DeleteExpired invocations.
type idempotencyCleanupRepoStub struct {
	deleteCalls int   // number of DeleteExpired calls
	lastLimit   int   // limit passed to the most recent call
	deleteErr   error // when set, DeleteExpired returns this error
}

// CreateProcessing is a stub; it never grants ownership.
func (r *idempotencyCleanupRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	return false, nil
}

// GetByScopeAndKeyHash is a stub; it always reports "not found".
func (r *idempotencyCleanupRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	return nil, nil
}

// TryReclaim is a stub; reclaims never succeed.
func (r *idempotencyCleanupRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, nil
}

// ExtendProcessingLock is a stub; extensions never succeed.
func (r *idempotencyCleanupRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, nil
}

// MarkSucceeded is a stub no-op.
func (r *idempotencyCleanupRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return nil
}

// MarkFailedRetryable is a stub no-op.
func (r *idempotencyCleanupRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return nil
}

// DeleteExpired records the call and either fails with deleteErr or
// reports one deleted row.
func (r *idempotencyCleanupRepoStub) DeleteExpired(_ context.Context, _ time.Time, limit int) (int64, error) {
	r.deleteCalls++
	r.lastLimit = limit
	if r.deleteErr != nil {
		return 0, r.deleteErr
	}
	return 1, nil
}
// TestNewIdempotencyCleanupService_UsesConfig verifies that positive
// config values override the built-in interval and batch defaults.
func TestNewIdempotencyCleanupService_UsesConfig(t *testing.T) {
	repo := &idempotencyCleanupRepoStub{}
	cfg := &config.Config{
		Idempotency: config.IdempotencyConfig{
			CleanupIntervalSeconds: 7,
			CleanupBatchSize:       321,
		},
	}
	svc := NewIdempotencyCleanupService(repo, cfg)
	require.Equal(t, 7*time.Second, svc.interval)
	require.Equal(t, 321, svc.batch)
}
// TestIdempotencyCleanupService_CleanupOnce verifies that a single
// cleanup pass calls DeleteExpired once with the configured batch size.
func TestIdempotencyCleanupService_CleanupOnce(t *testing.T) {
	repo := &idempotencyCleanupRepoStub{}
	svc := NewIdempotencyCleanupService(repo, &config.Config{
		Idempotency: config.IdempotencyConfig{
			CleanupBatchSize: 99,
		},
	})
	svc.cleanupOnce()
	require.Equal(t, 1, repo.deleteCalls)
	require.Equal(t, 99, repo.lastLimit)
}

View File

@@ -0,0 +1,171 @@
package service
import (
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// IdempotencyMetricsSnapshot 提供幂等核心指标快照(进程内累计)。
type IdempotencyMetricsSnapshot struct {
ClaimTotal uint64 `json:"claim_total"`
ReplayTotal uint64 `json:"replay_total"`
ConflictTotal uint64 `json:"conflict_total"`
RetryBackoffTotal uint64 `json:"retry_backoff_total"`
ProcessingDurationCount uint64 `json:"processing_duration_count"`
ProcessingDurationTotalMs float64 `json:"processing_duration_total_ms"`
StoreUnavailableTotal uint64 `json:"store_unavailable_total"`
}
// idempotencyMetrics holds the process-local atomic counters backing the
// idempotency metrics snapshot.
type idempotencyMetrics struct {
	claimTotal               atomic.Uint64
	replayTotal              atomic.Uint64
	conflictTotal            atomic.Uint64
	retryBackoffTotal        atomic.Uint64
	processingDurationCount  atomic.Uint64
	processingDurationMicros atomic.Uint64 // accumulated processing time, in microseconds
	storeUnavailableTotal    atomic.Uint64
}

// defaultIdempotencyMetrics is the package-wide metrics instance written by
// the record* helpers and read by GetIdempotencyMetricsSnapshot.
var defaultIdempotencyMetrics idempotencyMetrics
// GetIdempotencyMetricsSnapshot returns a point-in-time copy of the
// process-local idempotency counters.
func GetIdempotencyMetricsSnapshot() IdempotencyMetricsSnapshot {
	m := &defaultIdempotencyMetrics
	snapshot := IdempotencyMetricsSnapshot{
		ClaimTotal:              m.claimTotal.Load(),
		ReplayTotal:             m.replayTotal.Load(),
		ConflictTotal:           m.conflictTotal.Load(),
		RetryBackoffTotal:       m.retryBackoffTotal.Load(),
		ProcessingDurationCount: m.processingDurationCount.Load(),
		StoreUnavailableTotal:   m.storeUnavailableTotal.Load(),
	}
	// Durations accumulate in microseconds; expose milliseconds to callers.
	snapshot.ProcessingDurationTotalMs = float64(m.processingDurationMicros.Load()) / 1000.0
	return snapshot
}
func recordIdempotencyClaim(endpoint, scope string, attrs map[string]string) {
defaultIdempotencyMetrics.claimTotal.Add(1)
logIdempotencyMetric("idempotency_claim_total", endpoint, scope, "1", attrs)
}
func recordIdempotencyReplay(endpoint, scope string, attrs map[string]string) {
defaultIdempotencyMetrics.replayTotal.Add(1)
logIdempotencyMetric("idempotency_replay_total", endpoint, scope, "1", attrs)
}
func recordIdempotencyConflict(endpoint, scope string, attrs map[string]string) {
defaultIdempotencyMetrics.conflictTotal.Add(1)
logIdempotencyMetric("idempotency_conflict_total", endpoint, scope, "1", attrs)
}
func recordIdempotencyRetryBackoff(endpoint, scope string, attrs map[string]string) {
defaultIdempotencyMetrics.retryBackoffTotal.Add(1)
logIdempotencyMetric("idempotency_retry_backoff_total", endpoint, scope, "1", attrs)
}
// recordIdempotencyProcessingDuration accumulates one processing-duration
// sample (stored in microseconds) and logs the value in milliseconds.
func recordIdempotencyProcessingDuration(endpoint, scope string, duration time.Duration, attrs map[string]string) {
	if duration < 0 {
		// Clamp: a negative span would corrupt the unsigned accumulator.
		duration = 0
	}
	m := &defaultIdempotencyMetrics
	m.processingDurationCount.Add(1)
	m.processingDurationMicros.Add(uint64(duration.Microseconds()))
	valueMs := strconv.FormatFloat(duration.Seconds()*1000, 'f', 3, 64)
	logIdempotencyMetric("idempotency_processing_duration_ms", endpoint, scope, valueMs, attrs)
}
// RecordIdempotencyStoreUnavailable records an idempotency-store-unavailable
// event (used to observe degraded/fallback paths). A non-empty strategy is
// attached as an extra attribute.
func RecordIdempotencyStoreUnavailable(endpoint, scope, strategy string) {
	defaultIdempotencyMetrics.storeUnavailableTotal.Add(1)
	attrs := make(map[string]string, 1)
	if strategy != "" {
		attrs["strategy"] = strategy
	}
	logIdempotencyMetric("idempotency_store_unavailable_total", endpoint, scope, "1", attrs)
}
// logIdempotencyAudit emits one "[IdempotencyAudit]" key=value log line plus
// any extra attributes in sorted key order.
func logIdempotencyAudit(endpoint, scope, keyHash, stateTransition string, replayed bool, attrs map[string]string) {
	fields := [...][2]string{
		{"endpoint", endpoint},
		{"scope", scope},
		{"key_hash", keyHash},
		{"state_transition", stateTransition},
		{"replayed", strconv.FormatBool(replayed)},
	}
	var sb strings.Builder
	builderWriteString(&sb, "[IdempotencyAudit]")
	for _, field := range fields {
		builderWriteByte(&sb, ' ')
		builderWriteString(&sb, field[0])
		builderWriteByte(&sb, '=')
		builderWriteString(&sb, safeAuditField(field[1]))
	}
	appendSortedAttrs(&sb, attrs)
	logger.LegacyPrintf("service.idempotency", "%s", sb.String())
}
// logIdempotencyMetric emits one "[IdempotencyMetric]" key=value log line plus
// any extra attributes in sorted key order.
func logIdempotencyMetric(name, endpoint, scope, value string, attrs map[string]string) {
	fields := [...][2]string{
		{"name", name},
		{"endpoint", endpoint},
		{"scope", scope},
		{"value", value},
	}
	var sb strings.Builder
	builderWriteString(&sb, "[IdempotencyMetric]")
	for _, field := range fields {
		builderWriteByte(&sb, ' ')
		builderWriteString(&sb, field[0])
		builderWriteByte(&sb, '=')
		builderWriteString(&sb, safeAuditField(field[1]))
	}
	appendSortedAttrs(&sb, attrs)
	logger.LegacyPrintf("service.idempotency", "%s", sb.String())
}
// appendSortedAttrs writes every attribute as " key=value" in ascending key
// order so log lines stay deterministic; empty maps write nothing.
func appendSortedAttrs(builder *strings.Builder, attrs map[string]string) {
	if len(attrs) == 0 {
		return
	}
	// Map iteration order is random; collect and sort keys first.
	sortedKeys := make([]string, 0, len(attrs))
	for key := range attrs {
		sortedKeys = append(sortedKeys, key)
	}
	sort.Strings(sortedKeys)
	for _, key := range sortedKeys {
		builderWriteByte(builder, ' ')
		builderWriteString(builder, key)
		builderWriteByte(builder, '=')
		builderWriteString(builder, safeAuditField(attrs[key]))
	}
}
// safeAuditField normalizes a value for key=value log output: a blank value
// (after trimming) becomes "-" so the field is never empty, and whitespace
// characters that would break key=value parsing are rewritten to '_'.
func safeAuditField(v string) string {
	value := strings.TrimSpace(v)
	if value == "" {
		return "-"
	}
	// Logs are parsed as key=value pairs, so whitespace inside a value is
	// ambiguous. Rewrite it in a single pass (instead of four chained
	// ReplaceAll scans over the same string).
	return strings.Map(func(r rune) rune {
		switch r {
		case '\n', '\r', '\t', ' ':
			return '_'
		default:
			return r
		}
	}, value)
}
// resetIdempotencyMetricsForTest zeroes every process-local idempotency
// counter so tests can assert absolute values.
func resetIdempotencyMetricsForTest() {
	m := &defaultIdempotencyMetrics
	counters := []*atomic.Uint64{
		&m.claimTotal,
		&m.replayTotal,
		&m.conflictTotal,
		&m.retryBackoffTotal,
		&m.processingDurationCount,
		&m.processingDurationMicros,
		&m.storeUnavailableTotal,
	}
	for _, counter := range counters {
		counter.Store(0)
	}
}
func builderWriteString(builder *strings.Builder, value string) {
_, _ = builder.WriteString(value)
}
func builderWriteByte(builder *strings.Builder, value byte) {
_ = builder.WriteByte(value)
}

View File

@@ -0,0 +1,805 @@
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
// inMemoryIdempotencyRepo is a mutex-guarded in-memory idempotency repository
// used by coordinator tests in place of a real store.
type inMemoryIdempotencyRepo struct {
	mu     sync.Mutex // guards nextID and data
	nextID int64      // next record ID to assign on CreateProcessing
	data   map[string]*IdempotencyRecord // keyed by "scope|keyHash" (see key)
}
// newInMemoryIdempotencyRepo returns an empty repository with record IDs
// starting at 1.
func newInMemoryIdempotencyRepo() *inMemoryIdempotencyRepo {
	repo := &inMemoryIdempotencyRepo{data: map[string]*IdempotencyRecord{}}
	repo.nextID = 1
	return repo
}

// key builds the map key for a (scope, key-hash) pair.
func (r *inMemoryIdempotencyRepo) key(scope, hash string) string {
	const sep = "|"
	return scope + sep + hash
}
// cloneRecord returns a deep copy of in so callers cannot mutate stored state
// through shared pointer fields; a nil input yields nil.
func cloneRecord(in *IdempotencyRecord) *IdempotencyRecord {
	if in == nil {
		return nil
	}
	out := new(IdempotencyRecord)
	*out = *in
	if p := in.ResponseStatus; p != nil {
		status := *p
		out.ResponseStatus = &status
	}
	if p := in.ResponseBody; p != nil {
		body := *p
		out.ResponseBody = &body
	}
	if p := in.ErrorReason; p != nil {
		reason := *p
		out.ErrorReason = &reason
	}
	if p := in.LockedUntil; p != nil {
		until := *p
		out.LockedUntil = &until
	}
	return out
}
// CreateProcessing inserts record when no entry exists for its (scope,
// key-hash) pair. On success it assigns the new ID/timestamps back onto the
// caller's record and returns true; an already-claimed key returns false.
func (r *inMemoryIdempotencyRepo) CreateProcessing(_ context.Context, record *IdempotencyRecord) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	mapKey := r.key(record.Scope, record.IdempotencyKeyHash)
	if _, exists := r.data[mapKey]; exists {
		return false, nil
	}
	stored := cloneRecord(record)
	now := time.Now()
	stored.ID = r.nextID
	stored.CreatedAt = now
	stored.UpdatedAt = now
	r.nextID++
	r.data[mapKey] = stored
	// Reflect store-assigned fields back to the caller, like a real insert.
	record.ID = stored.ID
	record.CreatedAt = stored.CreatedAt
	record.UpdatedAt = stored.UpdatedAt
	return true, nil
}
// GetByScopeAndKeyHash returns a defensive copy of the stored record, or nil
// when nothing exists for the pair. The clone happens under the lock so a
// concurrent writer cannot race the copy.
func (r *inMemoryIdempotencyRepo) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*IdempotencyRecord, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	stored := r.data[r.key(scope, keyHash)]
	return cloneRecord(stored), nil
}
// TryReclaim moves the record with the given ID from fromStatus back into
// processing, but only when any existing lock has already expired at now.
// It returns false (without error) when the record is missing, in a different
// status, or still locked.
func (r *inMemoryIdempotencyRepo) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, stored := range r.data {
		if stored.ID != id {
			continue
		}
		if stored.Status != fromStatus {
			return false, nil
		}
		if lock := stored.LockedUntil; lock != nil && lock.After(now) {
			return false, nil
		}
		lockCopy := newLockedUntil
		stored.Status = IdempotencyStatusProcessing
		stored.LockedUntil = &lockCopy
		stored.ExpiresAt = newExpiresAt
		stored.ErrorReason = nil
		stored.UpdatedAt = time.Now()
		return true, nil
	}
	return false, nil
}
// ExtendProcessingLock renews the lock and TTL of a record, but only while it
// is still processing and owned by the same request fingerprint.
func (r *inMemoryIdempotencyRepo) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, stored := range r.data {
		if stored.ID != id {
			continue
		}
		stillOwned := stored.Status == IdempotencyStatusProcessing && stored.RequestFingerprint == requestFingerprint
		if !stillOwned {
			return false, nil
		}
		lockCopy := newLockedUntil
		stored.LockedUntil = &lockCopy
		stored.ExpiresAt = newExpiresAt
		stored.UpdatedAt = time.Now()
		return true, nil
	}
	return false, nil
}
// MarkSucceeded stores the terminal success state along with the response to
// replay, clearing the lock and any prior error reason. An unknown ID is an
// error.
func (r *inMemoryIdempotencyRepo) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, stored := range r.data {
		if stored.ID != id {
			continue
		}
		status, body := responseStatus, responseBody
		stored.Status = IdempotencyStatusSucceeded
		stored.LockedUntil = nil
		stored.ExpiresAt = expiresAt
		stored.UpdatedAt = time.Now()
		stored.ErrorReason = nil
		stored.ResponseStatus = &status
		stored.ResponseBody = &body
		return nil
	}
	return errors.New("record not found")
}
// MarkFailedRetryable records a retryable failure: the record stays alive
// until expiresAt and remains locked (backoff) until lockedUntil. An unknown
// ID is an error.
func (r *inMemoryIdempotencyRepo) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, stored := range r.data {
		if stored.ID != id {
			continue
		}
		reason := errorReason
		lockCopy := lockedUntil
		stored.Status = IdempotencyStatusFailedRetryable
		stored.LockedUntil = &lockCopy
		stored.ExpiresAt = expiresAt
		stored.UpdatedAt = time.Now()
		stored.ErrorReason = &reason
		return nil
	}
	return errors.New("record not found")
}
// DeleteExpired removes every record whose ExpiresAt is at or before now and
// returns the number removed. The batch limit is ignored by this in-memory
// implementation.
func (r *inMemoryIdempotencyRepo) DeleteExpired(_ context.Context, now time.Time, _ int) (int64, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	deleted := int64(0)
	for mapKey, stored := range r.data {
		if stored.ExpiresAt.After(now) {
			continue
		}
		delete(r.data, mapKey)
		deleted++
	}
	return deleted, nil
}
// TestIdempotencyCoordinator_RequireKey verifies that with ObserveOnly
// disabled, a request lacking an idempotency key is rejected.
func TestIdempotencyCoordinator_RequireKey(t *testing.T) {
	resetIdempotencyMetricsForTest()
	cfg := DefaultIdempotencyConfig()
	cfg.ObserveOnly = false
	coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg)
	opts := IdempotencyExecuteOptions{
		Scope:      "test.scope",
		Method:     "POST",
		Route:      "/test",
		ActorScope: "admin:1",
		RequireKey: true,
		Payload:    map[string]any{"a": 1},
	}
	_, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyRequired))
}
// TestIdempotencyCoordinator_ReplaySucceededResult verifies that a second
// request with the same key replays the stored result instead of re-running
// the executor.
func TestIdempotencyCoordinator_ReplaySucceededResult(t *testing.T) {
	resetIdempotencyMetricsForTest()
	coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), DefaultIdempotencyConfig())
	calls := 0
	executor := func(ctx context.Context) (any, error) {
		calls++
		return map[string]any{"count": calls}, nil
	}
	opts := IdempotencyExecuteOptions{
		Scope:          "test.scope",
		Method:         "POST",
		Route:          "/test",
		ActorScope:     "user:1",
		RequireKey:     true,
		IdempotencyKey: "case-1",
		Payload:        map[string]any{"a": 1},
	}
	first, err := coordinator.Execute(context.Background(), opts, executor)
	require.NoError(t, err)
	require.False(t, first.Replayed)
	second, err := coordinator.Execute(context.Background(), opts, executor)
	require.NoError(t, err)
	require.True(t, second.Replayed)
	require.Equal(t, 1, calls, "second request should replay without executing business logic")
	snapshot := GetIdempotencyMetricsSnapshot()
	require.Equal(t, uint64(1), snapshot.ClaimTotal)
	require.Equal(t, uint64(1), snapshot.ReplayTotal)
}
// TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord verifies that once
// a succeeded record's TTL has elapsed, the next request with the same key is
// reclaimed (business logic runs again) and later requests replay the fresh
// result.
func TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord(t *testing.T) {
	resetIdempotencyMetricsForTest()
	repo := newInMemoryIdempotencyRepo()
	coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
	opts := IdempotencyExecuteOptions{
		Scope:          "test.scope.expired",
		Method:         "POST",
		Route:          "/test/expired",
		ActorScope:     "user:99",
		RequireKey:     true,
		IdempotencyKey: "expired-case",
		Payload:        map[string]any{"k": "v"},
	}
	execCount := 0
	exec := func(ctx context.Context) (any, error) {
		execCount++
		return map[string]any{"count": execCount}, nil
	}
	// First call claims the key and executes the business logic once.
	first, err := coordinator.Execute(context.Background(), opts, exec)
	require.NoError(t, err)
	require.NotNil(t, first)
	require.False(t, first.Replayed)
	require.Equal(t, 1, execCount)
	// Force-expire the stored record by reaching into the in-memory repo.
	keyHash := HashIdempotencyKey(opts.IdempotencyKey)
	repo.mu.Lock()
	existing := repo.data[repo.key(opts.Scope, keyHash)]
	require.NotNil(t, existing)
	existing.ExpiresAt = time.Now().Add(-time.Second)
	repo.mu.Unlock()
	// The expired record is reclaimed: business logic runs a second time.
	second, err := coordinator.Execute(context.Background(), opts, exec)
	require.NoError(t, err)
	require.NotNil(t, second)
	require.False(t, second.Replayed, "expired record should be reclaimed and execute business logic again")
	require.Equal(t, 2, execCount)
	// A third call replays the second result without executing again.
	third, err := coordinator.Execute(context.Background(), opts, exec)
	require.NoError(t, err)
	require.NotNil(t, third)
	require.True(t, third.Replayed)
	payload, ok := third.Data.(map[string]any)
	require.True(t, ok)
	// Replayed numbers arrive as float64, consistent with a JSON round-trip
	// of the stored response.
	require.Equal(t, float64(2), payload["count"])
	metrics := GetIdempotencyMetricsSnapshot()
	require.GreaterOrEqual(t, metrics.ClaimTotal, uint64(2))
	require.GreaterOrEqual(t, metrics.ReplayTotal, uint64(1))
}
// TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict verifies reusing
// a key with a different payload is rejected as a key conflict.
func TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict(t *testing.T) {
	resetIdempotencyMetricsForTest()
	coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), DefaultIdempotencyConfig())
	makeOpts := func(a int) IdempotencyExecuteOptions {
		return IdempotencyExecuteOptions{
			Scope:          "test.scope",
			Method:         "POST",
			Route:          "/test",
			ActorScope:     "user:1",
			RequireKey:     true,
			IdempotencyKey: "case-2",
			Payload:        map[string]any{"a": a},
		}
	}
	executor := func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	}
	_, err := coordinator.Execute(context.Background(), makeOpts(1), executor)
	require.NoError(t, err)
	// Same key, different payload fingerprint: must conflict.
	_, err = coordinator.Execute(context.Background(), makeOpts(2), executor)
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyConflict))
	require.Equal(t, uint64(1), GetIdempotencyMetricsSnapshot().ConflictTotal)
}
// TestIdempotencyCoordinator_BackoffAfterRetryableFailure verifies that after
// a retryable failure, a retry within the backoff window is rejected and the
// error carries a positive retry_after hint.
func TestIdempotencyCoordinator_BackoffAfterRetryableFailure(t *testing.T) {
	resetIdempotencyMetricsForTest()
	cfg := DefaultIdempotencyConfig()
	cfg.FailedRetryBackoff = 2 * time.Second
	coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg)
	opts := IdempotencyExecuteOptions{
		Scope:          "test.scope",
		Method:         "POST",
		Route:          "/test",
		ActorScope:     "user:1",
		RequireKey:     true,
		IdempotencyKey: "case-3",
		Payload:        map[string]any{"a": 1},
	}
	// First attempt fails with a retryable upstream error.
	_, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
		return nil, infraerrors.InternalServer("UPSTREAM_ERROR", "upstream error")
	})
	require.Error(t, err)
	// Immediate retry lands inside the backoff window.
	_, err = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyRetryBackoff))
	require.Greater(t, RetryAfterSecondsFromError(err), 0)
	snapshot := GetIdempotencyMetricsSnapshot()
	require.GreaterOrEqual(t, snapshot.RetryBackoffTotal, uint64(2))
	require.GreaterOrEqual(t, snapshot.ConflictTotal, uint64(1))
	require.GreaterOrEqual(t, snapshot.ProcessingDurationCount, uint64(1))
}
// TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect fires eight
// concurrent requests with one idempotency key and asserts the business side
// effect runs exactly once; a later sequential request replays the stored
// result.
func TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect(t *testing.T) {
	resetIdempotencyMetricsForTest()
	repo := newInMemoryIdempotencyRepo()
	cfg := DefaultIdempotencyConfig()
	// Long enough that the 80ms executor finishes well before the winner's
	// processing lock could be reclaimed.
	cfg.ProcessingTimeout = 2 * time.Second
	coordinator := NewIdempotencyCoordinator(repo, cfg)
	opts := IdempotencyExecuteOptions{
		Scope:          "test.scope.concurrent",
		Method:         "POST",
		Route:          "/test/concurrent",
		ActorScope:     "user:7",
		RequireKey:     true,
		IdempotencyKey: "concurrent-case",
		Payload:        map[string]any{"v": 1},
	}
	var execCount int32
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Errors are expected for losing goroutines; only the
			// side-effect count is asserted.
			_, _ = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
				atomic.AddInt32(&execCount, 1)
				time.Sleep(80 * time.Millisecond)
				return map[string]any{"ok": true}, nil
			})
		}()
	}
	wg.Wait()
	// After the dust settles, the same key replays without re-executing.
	replayed, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
		atomic.AddInt32(&execCount, 1)
		return map[string]any{"ok": true}, nil
	})
	require.NoError(t, err)
	require.True(t, replayed.Replayed)
	require.Equal(t, int32(1), atomic.LoadInt32(&execCount), "concurrent same-key requests should execute business side-effect once")
	metrics := GetIdempotencyMetricsSnapshot()
	require.Equal(t, uint64(1), metrics.ClaimTotal)
	require.Equal(t, uint64(1), metrics.ReplayTotal)
	require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1))
}
// failingIdempotencyRepo simulates a completely unavailable idempotency
// store: every repository method fails with a "store unavailable" error.
type failingIdempotencyRepo struct{}

func (failingIdempotencyRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	return false, errors.New("store unavailable")
}
func (failingIdempotencyRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	return nil, errors.New("store unavailable")
}
func (failingIdempotencyRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, errors.New("store unavailable")
}
func (failingIdempotencyRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, errors.New("store unavailable")
}
func (failingIdempotencyRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return errors.New("store unavailable")
}
func (failingIdempotencyRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return errors.New("store unavailable")
}
func (failingIdempotencyRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) {
	return 0, errors.New("store unavailable")
}
// TestIdempotencyCoordinator_StoreUnavailableMetrics verifies that repository
// failures surface as the store-unavailable error and bump its counter.
func TestIdempotencyCoordinator_StoreUnavailableMetrics(t *testing.T) {
	resetIdempotencyMetricsForTest()
	coordinator := NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig())
	opts := IdempotencyExecuteOptions{
		Scope:          "test.scope.unavailable",
		Method:         "POST",
		Route:          "/test/unavailable",
		ActorScope:     "admin:1",
		RequireKey:     true,
		IdempotencyKey: "case-unavailable",
		Payload:        map[string]any{"v": 1},
	}
	_, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	require.GreaterOrEqual(t, GetIdempotencyMetricsSnapshot().StoreUnavailableTotal, uint64(1))
}
// TestDefaultIdempotencyCoordinatorAndTTLs verifies the package-level default
// coordinator accessor and the TTL helpers both with and without a default
// coordinator installed.
func TestDefaultIdempotencyCoordinatorAndTTLs(t *testing.T) {
	SetDefaultIdempotencyCoordinator(nil)
	require.Nil(t, DefaultIdempotencyCoordinator())
	// Without a default coordinator the TTL helpers fall back to config defaults.
	defaults := DefaultIdempotencyConfig()
	require.Equal(t, defaults.DefaultTTL, DefaultWriteIdempotencyTTL())
	require.Equal(t, defaults.SystemOperationTTL, DefaultSystemOperationIdempotencyTTL())
	custom := IdempotencyConfig{
		DefaultTTL:         2 * time.Hour,
		SystemOperationTTL: 15 * time.Minute,
		ProcessingTimeout:  10 * time.Second,
		FailedRetryBackoff: 3 * time.Second,
		ObserveOnly:        false,
	}
	coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), custom)
	SetDefaultIdempotencyCoordinator(coordinator)
	t.Cleanup(func() {
		SetDefaultIdempotencyCoordinator(nil)
	})
	require.Same(t, coordinator, DefaultIdempotencyCoordinator())
	require.Equal(t, 2*time.Hour, DefaultWriteIdempotencyTTL())
	require.Equal(t, 15*time.Minute, DefaultSystemOperationIdempotencyTTL())
}
// TestNormalizeIdempotencyKeyAndFingerprint covers key normalization rules
// (trimming, empty passthrough, length and character limits) and fingerprint
// defaulting/error behavior.
func TestNormalizeIdempotencyKeyAndFingerprint(t *testing.T) {
	normalized, err := NormalizeIdempotencyKey(" abc-123 ")
	require.NoError(t, err)
	require.Equal(t, "abc-123", normalized)
	normalized, err = NormalizeIdempotencyKey("")
	require.NoError(t, err)
	require.Equal(t, "", normalized)
	// A 129-byte key exceeds the length limit.
	_, err = NormalizeIdempotencyKey(string(make([]byte, 129)))
	require.Error(t, err)
	// Control characters are rejected.
	_, err = NormalizeIdempotencyKey("bad\nkey")
	require.Error(t, err)
	// Blank method/route/actor default to the same canonical fingerprint.
	fpDefault, err := BuildIdempotencyFingerprint("", "", "", map[string]any{"a": 1})
	require.NoError(t, err)
	require.NotEmpty(t, fpDefault)
	fpExplicit, err := BuildIdempotencyFingerprint("POST", "/", "anonymous", map[string]any{"a": 1})
	require.NoError(t, err)
	require.Equal(t, fpDefault, fpExplicit)
	// Unserializable payloads surface the invalid-payload error code.
	_, err = BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"bad": make(chan int)})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyInvalidPayload), infraerrors.Code(err))
}
// TestRetryAfterSecondsFromErrorBranches covers nil errors, plain errors, and
// valid vs. malformed retry_after metadata.
func TestRetryAfterSecondsFromErrorBranches(t *testing.T) {
	require.Equal(t, 0, RetryAfterSecondsFromError(nil))
	require.Equal(t, 0, RetryAfterSecondsFromError(errors.New("plain")))
	withValid := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "12"})
	require.Equal(t, 12, RetryAfterSecondsFromError(withValid))
	withMalformed := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "bad"})
	require.Equal(t, 0, RetryAfterSecondsFromError(withMalformed))
}
func TestIdempotencyCoordinator_ExecuteNilExecutorAndNoKeyPassThrough(t *testing.T) {
repo := newInMemoryIdempotencyRepo()
coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
Scope: "scope",
IdempotencyKey: "k",
Payload: map[string]any{"a": 1},
}, nil)
require.Error(t, err)
require.Equal(t, "IDEMPOTENCY_EXECUTOR_NIL", infraerrors.Reason(err))
called := 0
result, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
Scope: "scope",
RequireKey: true,
Payload: map[string]any{"a": 1},
}, func(ctx context.Context) (any, error) {
called++
return map[string]any{"ok": true}, nil
})
require.NoError(t, err)
require.Equal(t, 1, called)
require.NotNil(t, result)
require.False(t, result.Replayed)
}
// noIDOwnerRepo reports a successful claim from CreateProcessing but never
// assigns an ID to the record, and lookups return nil — exercising the
// coordinator's handling of a store that "succeeds" without an owned record.
type noIDOwnerRepo struct{}

func (noIDOwnerRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	return true, nil
}
func (noIDOwnerRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	return nil, nil
}
func (noIDOwnerRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, nil
}
func (noIDOwnerRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, nil
}
func (noIDOwnerRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { return nil }
func (noIDOwnerRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return nil
}
func (noIDOwnerRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { return 0, nil }
// TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing covers
// three setup failures: a nil repository, a missing scope, and a repository
// that claims success without assigning a record ID.
func TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing(t *testing.T) {
	cfg := DefaultIdempotencyConfig()
	// Case 1: nil repository -> store unavailable.
	coordinator := NewIdempotencyCoordinator(nil, cfg)
	_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Case 2: missing scope -> scope-required error.
	coordinator = NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg)
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		IdempotencyKey: "k2",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, "IDEMPOTENCY_SCOPE_REQUIRED", infraerrors.Reason(err))
	// Case 3: claim reported successful but no record ID assigned -> treated
	// as store unavailable.
	coordinator = NewIdempotencyCoordinator(noIDOwnerRepo{}, cfg)
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope-no-id",
		IdempotencyKey: "k3",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
}
// conflictBranchRepo always reports "already exists" on CreateProcessing,
// serves a configurable existing record from lookups, and lets tests script
// TryReclaim's outcome to drive specific coordinator branches.
type conflictBranchRepo struct {
	existing      *IdempotencyRecord // record returned (cloned) by GetByScopeAndKeyHash
	tryReclaimErr error              // when non-nil, TryReclaim fails with this error
	tryReclaimOK  bool               // TryReclaim result when tryReclaimErr is nil
}

func (r *conflictBranchRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	return false, nil
}
func (r *conflictBranchRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	return cloneRecord(r.existing), nil
}
func (r *conflictBranchRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	if r.tryReclaimErr != nil {
		return false, r.tryReclaimErr
	}
	return r.tryReclaimOK, nil
}
func (r *conflictBranchRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, nil
}
func (r *conflictBranchRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return nil
}
func (r *conflictBranchRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return nil
}
func (r *conflictBranchRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) {
	return 0, nil
}
// TestIdempotencyCoordinator_ConflictBranchesAndDecodeError walks the lookup
// branches taken when CreateProcessing reports an existing record: an
// undecodable stored response, an unknown record status, a TryReclaim
// failure, and a refused reclaim.
func TestIdempotencyCoordinator_ConflictBranchesAndDecodeError(t *testing.T) {
	now := time.Now()
	fp, err := BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"a": 1})
	require.NoError(t, err)
	// Branch 1: succeeded record whose stored body is not valid JSON ->
	// replay decoding fails and surfaces as store-unavailable.
	badBody := "{bad-json"
	repo := &conflictBranchRepo{
		existing: &IdempotencyRecord{
			ID:                 1,
			Scope:              "scope",
			IdempotencyKeyHash: HashIdempotencyKey("k"),
			RequestFingerprint: fp,
			Status:             IdempotencyStatusSucceeded,
			ResponseBody:       &badBody,
			ExpiresAt:          now.Add(time.Hour),
		},
	}
	coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Method:         "POST",
		Route:          "/x",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Branch 2: record in an unrecognized status -> key conflict.
	repo.existing = &IdempotencyRecord{
		ID:                 2,
		Scope:              "scope",
		IdempotencyKeyHash: HashIdempotencyKey("k"),
		RequestFingerprint: fp,
		Status:             "unknown",
		ExpiresAt:          now.Add(time.Hour),
	}
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Method:         "POST",
		Route:          "/x",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyKeyConflict), infraerrors.Code(err))
	// Branch 3: retryable record with an expired lock, but TryReclaim errors
	// -> store unavailable.
	repo.existing = &IdempotencyRecord{
		ID:                 3,
		Scope:              "scope",
		IdempotencyKeyHash: HashIdempotencyKey("k"),
		RequestFingerprint: fp,
		Status:             IdempotencyStatusFailedRetryable,
		LockedUntil:        ptrTime(now.Add(-time.Second)),
		ExpiresAt:          now.Add(time.Hour),
	}
	repo.tryReclaimErr = errors.New("reclaim down")
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Method:         "POST",
		Route:          "/x",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Branch 4: TryReclaim cleanly refuses -> request is reported in-progress.
	repo.tryReclaimErr = nil
	repo.tryReclaimOK = false
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Method:         "POST",
		Route:          "/x",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyInProgress), infraerrors.Code(err))
}
// markBehaviorRepo wraps the in-memory repo and lets tests force
// MarkSucceeded or MarkFailedRetryable to fail.
type markBehaviorRepo struct {
	inMemoryIdempotencyRepo
	failMarkSucceeded bool // when true, MarkSucceeded returns an error
	failMarkFailed    bool // when true, MarkFailedRetryable returns an error
}

func (r *markBehaviorRepo) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
	if r.failMarkSucceeded {
		return errors.New("mark succeeded failed")
	}
	return r.inMemoryIdempotencyRepo.MarkSucceeded(ctx, id, responseStatus, responseBody, expiresAt)
}
func (r *markBehaviorRepo) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
	if r.failMarkFailed {
		return errors.New("mark failed retryable failed")
	}
	return r.inMemoryIdempotencyRepo.MarkFailedRetryable(ctx, id, errorReason, lockedUntil, expiresAt)
}
// TestIdempotencyCoordinator_MarkAndMarshalBranches covers the persistence
// branches after the executor runs: MarkSucceeded failing, the result being
// unmarshalable, and MarkFailedRetryable failing after an executor error.
func TestIdempotencyCoordinator_MarkAndMarshalBranches(t *testing.T) {
	repo := &markBehaviorRepo{inMemoryIdempotencyRepo: *newInMemoryIdempotencyRepo()}
	coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
	// Branch 1: executor succeeds but MarkSucceeded fails -> store unavailable.
	repo.failMarkSucceeded = true
	_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope-success",
		IdempotencyKey: "k1",
		Method:         "POST",
		Route:          "/ok",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Branch 2: executor result cannot be marshaled (contains a channel) ->
	// store unavailable.
	repo.failMarkSucceeded = false
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope-marshal",
		IdempotencyKey: "k2",
		Method:         "POST",
		Route:          "/bad",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"bad": make(chan int)}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Branch 3: executor fails and MarkFailedRetryable also fails -> the
	// original executor error is surfaced to the caller.
	repo.failMarkFailed = true
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope-fail",
		IdempotencyKey: "k3",
		Method:         "POST",
		Route:          "/fail",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return nil, errors.New("plain failure")
	})
	require.Error(t, err)
	require.Equal(t, "plain failure", err.Error())
}
// TestIdempotencyCoordinator_HelperBranches exercises coordinator helper
// methods directly: retry-after metadata fallback, stored-response
// truncation, and stored-response decoding edge cases.
func TestIdempotencyCoordinator_HelperBranches(t *testing.T) {
	c := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{
		DefaultTTL:           time.Hour,
		SystemOperationTTL:   time.Hour,
		ProcessingTimeout:    time.Second,
		FailedRetryBackoff:   time.Second,
		MaxStoredResponseLen: 12,
		ObserveOnly:          false,
	})
	// conflictWithRetryAfter without locked_until should return base error.
	base := ErrIdempotencyInProgress
	err := c.conflictWithRetryAfter(base, nil, time.Now())
	require.Equal(t, infraerrors.Code(base), infraerrors.Code(err))
	// marshalStoredResponse should truncate (MaxStoredResponseLen is 12).
	body, err := c.marshalStoredResponse(map[string]any{"long": "abcdefghijklmnopqrstuvwxyz"})
	require.NoError(t, err)
	require.Contains(t, body, "...(truncated)")
	// decodeStoredResponse empty and invalid json.
	out, err := c.decodeStoredResponse(nil)
	require.NoError(t, err)
	_, ok := out.(map[string]any)
	require.True(t, ok)
	invalid := "{invalid"
	_, err = c.decodeStoredResponse(&invalid)
	require.Error(t, err)
}

View File

@@ -0,0 +1,389 @@
package service
import (
"context"
"strconv"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// groupRepoNoop is a group-repository stub whose every method panics; embed
// it in a test stub and override only the methods the test expects to be
// called, so any unexpected repository access fails loudly.
type groupRepoNoop struct{}

func (groupRepoNoop) Create(context.Context, *Group) error { panic("unexpected Create call") }
func (groupRepoNoop) GetByID(context.Context, int64) (*Group, error) {
	panic("unexpected GetByID call")
}
func (groupRepoNoop) GetByIDLite(context.Context, int64) (*Group, error) {
	panic("unexpected GetByIDLite call")
}
func (groupRepoNoop) Update(context.Context, *Group) error { panic("unexpected Update call") }
func (groupRepoNoop) Delete(context.Context, int64) error  { panic("unexpected Delete call") }
func (groupRepoNoop) DeleteCascade(context.Context, int64) ([]int64, error) {
	panic("unexpected DeleteCascade call")
}
func (groupRepoNoop) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected List call")
}
func (groupRepoNoop) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected ListWithFilters call")
}
func (groupRepoNoop) ListActive(context.Context) ([]Group, error) {
	panic("unexpected ListActive call")
}
func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, error) {
	panic("unexpected ListActiveByPlatform call")
}
func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) {
	panic("unexpected ExistsByName call")
}
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) {
	panic("unexpected GetAccountCount call")
}
func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
	panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (groupRepoNoop) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
	panic("unexpected GetAccountIDsByGroupIDs call")
}
func (groupRepoNoop) BindAccountsToGroup(context.Context, int64, []int64) error {
	panic("unexpected BindAccountsToGroup call")
}
func (groupRepoNoop) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
	panic("unexpected UpdateSortOrders call")
}
// subscriptionGroupRepoStub overrides GetByID to always return the configured
// fixture group; every other repository method falls through to the panicking
// noop embedded below.
type subscriptionGroupRepoStub struct {
	groupRepoNoop

	group *Group
}

// GetByID ignores its arguments and returns the fixture group with no error.
func (s *subscriptionGroupRepoStub) GetByID(_ context.Context, _ int64) (*Group, error) {
	return s.group, nil
}
// userSubRepoNoop is a user-subscription repository whose every method
// panics, ensuring tests notice any repository call they did not intend to
// exercise; concrete stubs embed it and override only what they need.
type userSubRepoNoop struct{}
func (userSubRepoNoop) Create(context.Context, *UserSubscription) error {
	panic("unexpected Create call")
}
func (userSubRepoNoop) GetByID(context.Context, int64) (*UserSubscription, error) {
	panic("unexpected GetByID call")
}
func (userSubRepoNoop) GetByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
	panic("unexpected GetByUserIDAndGroupID call")
}
func (userSubRepoNoop) GetActiveByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
	panic("unexpected GetActiveByUserIDAndGroupID call")
}
func (userSubRepoNoop) Update(context.Context, *UserSubscription) error {
	panic("unexpected Update call")
}
func (userSubRepoNoop) Delete(context.Context, int64) error { panic("unexpected Delete call") }
func (userSubRepoNoop) ListByUserID(context.Context, int64) ([]UserSubscription, error) {
	panic("unexpected ListByUserID call")
}
func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscription, error) {
	panic("unexpected ListActiveByUserID call")
}
func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
	panic("unexpected ListByGroupID call")
}
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
	panic("unexpected List call")
}
func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {
	panic("unexpected ExistsByUserIDAndGroupID call")
}
func (userSubRepoNoop) ExtendExpiry(context.Context, int64, time.Time) error {
	panic("unexpected ExtendExpiry call")
}
func (userSubRepoNoop) UpdateStatus(context.Context, int64, string) error {
	panic("unexpected UpdateStatus call")
}
func (userSubRepoNoop) UpdateNotes(context.Context, int64, string) error {
	panic("unexpected UpdateNotes call")
}
func (userSubRepoNoop) ActivateWindows(context.Context, int64, time.Time) error {
	panic("unexpected ActivateWindows call")
}
func (userSubRepoNoop) ResetDailyUsage(context.Context, int64, time.Time) error {
	panic("unexpected ResetDailyUsage call")
}
func (userSubRepoNoop) ResetWeeklyUsage(context.Context, int64, time.Time) error {
	panic("unexpected ResetWeeklyUsage call")
}
func (userSubRepoNoop) ResetMonthlyUsage(context.Context, int64, time.Time) error {
	panic("unexpected ResetMonthlyUsage call")
}
func (userSubRepoNoop) IncrementUsage(context.Context, int64, float64) error {
	panic("unexpected IncrementUsage call")
}
func (userSubRepoNoop) BatchUpdateExpiredStatus(context.Context) (int64, error) {
	panic("unexpected BatchUpdateExpiredStatus call")
}
// subscriptionUserSubRepoStub is an in-memory user-subscription repository
// for assignment tests: it counts Create calls and indexes records both by
// ID and by the (user, group) pair.
type subscriptionUserSubRepoStub struct {
	userSubRepoNoop
	nextID      int64
	byID        map[int64]*UserSubscription
	byUserGroup map[string]*UserSubscription
	createCalls int
}

// newSubscriptionUserSubRepoStub returns an empty stub whose next assigned ID is 1.
func newSubscriptionUserSubRepoStub() *subscriptionUserSubRepoStub {
	stub := &subscriptionUserSubRepoStub{nextID: 1}
	stub.byID = make(map[int64]*UserSubscription)
	stub.byUserGroup = make(map[string]*UserSubscription)
	return stub
}
// key builds the composite "userID:groupID" map key used to index
// subscriptions by user and group. Calls strconv.FormatInt directly instead
// of going through the local strconvFormatInt wrapper — the indirection added
// nothing.
func (s *subscriptionUserSubRepoStub) key(userID, groupID int64) string {
	return strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10)
}
// seed inserts a pre-existing subscription into both indexes without counting
// it as a Create call. Nil input is ignored; a zero ID gets auto-assigned.
func (s *subscriptionUserSubRepoStub) seed(sub *UserSubscription) {
	if sub == nil {
		return
	}
	record := *sub
	if record.ID == 0 {
		record.ID = s.nextID
		s.nextID++
	}
	s.byID[record.ID] = &record
	s.byUserGroup[s.key(record.UserID, record.GroupID)] = &record
}
// ExistsByUserIDAndGroupID reports whether a subscription is indexed for the pair.
func (s *subscriptionUserSubRepoStub) ExistsByUserIDAndGroupID(_ context.Context, userID, groupID int64) (bool, error) {
	_, found := s.byUserGroup[s.key(userID, groupID)]
	return found, nil
}
// GetByUserIDAndGroupID returns a copy of the stored subscription for the
// pair, or ErrSubscriptionNotFound when none is indexed.
func (s *subscriptionUserSubRepoStub) GetByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
	stored := s.byUserGroup[s.key(userID, groupID)]
	if stored == nil {
		return nil, ErrSubscriptionNotFound
	}
	copied := *stored
	return &copied, nil
}
// Create stores a copy of sub, assigning a fresh ID when none is set, and
// bumps createCalls. The caller's struct receives the final ID.
func (s *subscriptionUserSubRepoStub) Create(_ context.Context, sub *UserSubscription) error {
	if sub == nil {
		return nil
	}
	s.createCalls++
	record := *sub
	if record.ID == 0 {
		record.ID = s.nextID
		s.nextID++
	}
	sub.ID = record.ID
	s.byID[record.ID] = &record
	s.byUserGroup[s.key(record.UserID, record.GroupID)] = &record
	return nil
}
// GetByID returns a copy of the subscription with the given ID, or
// ErrSubscriptionNotFound when it does not exist.
func (s *subscriptionUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) {
	stored := s.byID[id]
	if stored == nil {
		return nil, ErrSubscriptionNotFound
	}
	copied := *stored
	return &copied, nil
}
// TestAssignSubscriptionReuseWhenSemanticsMatch: when an assignment already
// exists with the same validity window and notes, AssignSubscription returns
// the existing subscription instead of creating a new one.
func TestAssignSubscriptionReuseWhenSemanticsMatch(t *testing.T) {
	start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
	groupRepo := &subscriptionGroupRepoStub{
		group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
	}
	subRepo := newSubscriptionUserSubRepoStub()
	// Seeded expiry (start + 30 days) and notes match the request below, so
	// the request is semantically identical to the stored assignment.
	subRepo.seed(&UserSubscription{
		ID:        10,
		UserID:    1001,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 30),
		Notes:     "init",
	})
	svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
	sub, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
		UserID:       1001,
		GroupID:      1,
		ValidityDays: 30,
		Notes:        "init",
	})
	require.NoError(t, err)
	require.Equal(t, int64(10), sub.ID)
	require.Equal(t, 0, subRepo.createCalls, "reuse should not create new subscription")
}
// TestAssignSubscriptionConflictWhenSemanticsMismatch: an existing assignment
// whose notes differ from the request is rejected with
// SUBSCRIPTION_ASSIGN_CONFLICT and nothing is created or mutated.
func TestAssignSubscriptionConflictWhenSemanticsMismatch(t *testing.T) {
	start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
	groupRepo := &subscriptionGroupRepoStub{
		group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
	}
	subRepo := newSubscriptionUserSubRepoStub()
	subRepo.seed(&UserSubscription{
		ID:        11,
		UserID:    2001,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 30),
		Notes:     "old-note",
	})
	svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
	_, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
		UserID:       2001,
		GroupID:      1,
		ValidityDays: 30,
		Notes:        "new-note",
	})
	require.Error(t, err)
	require.Equal(t, "SUBSCRIPTION_ASSIGN_CONFLICT", infraerrorsReason(err))
	require.Equal(t, 0, subRepo.createCalls, "conflict should not create or mutate existing subscription")
}
// TestBulkAssignSubscriptionCreatedReusedAndConflict: bulk assignment reports
// a per-user status — "reused" for a semantically matching assignment,
// "created" for a fresh user, "failed" for a semantic conflict — with
// matching aggregate counters.
func TestBulkAssignSubscriptionCreatedReusedAndConflict(t *testing.T) {
	start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
	groupRepo := &subscriptionGroupRepoStub{
		group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
	}
	subRepo := newSubscriptionUserSubRepoStub()
	// user 1: semantics match the request below, so the assignment is reused
	subRepo.seed(&UserSubscription{
		ID:        21,
		UserID:    1,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 30),
		Notes:     "same-note",
	})
	// user 3: semantic conflict (validity period differs), so it should fail
	subRepo.seed(&UserSubscription{
		ID:        23,
		UserID:    3,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 60),
		Notes:     "same-note",
	})
	svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
	result, err := svc.BulkAssignSubscription(context.Background(), &BulkAssignSubscriptionInput{
		UserIDs:      []int64{1, 2, 3},
		GroupID:      1,
		ValidityDays: 30,
		AssignedBy:   9,
		Notes:        "same-note",
	})
	require.NoError(t, err)
	require.Equal(t, 2, result.SuccessCount)
	require.Equal(t, 1, result.CreatedCount)
	require.Equal(t, 1, result.ReusedCount)
	require.Equal(t, 1, result.FailedCount)
	require.Equal(t, "reused", result.Statuses[1])
	require.Equal(t, "created", result.Statuses[2])
	require.Equal(t, "failed", result.Statuses[3])
	require.Equal(t, 1, subRepo.createCalls)
}
// TestAssignSubscriptionKeepsWorkingWhenIdempotencyStoreUnavailable: the
// assign endpoint is semantically idempotent on its own, so it must succeed
// even when the process-wide idempotency coordinator is backed by a failing
// store.
func TestAssignSubscriptionKeepsWorkingWhenIdempotencyStoreUnavailable(t *testing.T) {
	groupRepo := &subscriptionGroupRepoStub{
		group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
	}
	subRepo := newSubscriptionUserSubRepoStub()
	// Install a coordinator whose store always fails; restore the default after.
	SetDefaultIdempotencyCoordinator(NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig()))
	t.Cleanup(func() {
		SetDefaultIdempotencyCoordinator(nil)
	})
	svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
	sub, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
		UserID:       9001,
		GroupID:      1,
		ValidityDays: 30,
		Notes:        "new",
	})
	require.NoError(t, err)
	require.NotNil(t, sub)
	require.Equal(t, 1, subRepo.createCalls, "semantic idempotent endpoint should not depend on idempotency store availability")
}
func TestNormalizeAssignValidityDays(t *testing.T) {
require.Equal(t, 30, normalizeAssignValidityDays(0))
require.Equal(t, 30, normalizeAssignValidityDays(-5))
require.Equal(t, MaxValidityDays, normalizeAssignValidityDays(MaxValidityDays+100))
require.Equal(t, 7, normalizeAssignValidityDays(7))
}
// TestDetectAssignSemanticConflictCases exercises detectAssignSemanticConflict:
// matching semantics yield no conflict, while validity-period and notes
// mismatches are reported with their respective reason codes.
func TestDetectAssignSemanticConflictCases(t *testing.T) {
	start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
	base := &UserSubscription{
		UserID:    1,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 30),
		Notes:     "same",
	}
	// Same validity window and notes: no conflict.
	reason, conflict := detectAssignSemanticConflict(base, &AssignSubscriptionInput{
		UserID:       1,
		GroupID:      1,
		ValidityDays: 30,
		Notes:        "same",
	})
	require.False(t, conflict)
	require.Equal(t, "", reason)
	// Different validity period: validity_days_mismatch.
	reason, conflict = detectAssignSemanticConflict(base, &AssignSubscriptionInput{
		UserID:       1,
		GroupID:      1,
		ValidityDays: 60,
		Notes:        "same",
	})
	require.True(t, conflict)
	require.Equal(t, "validity_days_mismatch", reason)
	// Different notes: notes_mismatch.
	reason, conflict = detectAssignSemanticConflict(base, &AssignSubscriptionInput{
		UserID:       1,
		GroupID:      1,
		ValidityDays: 30,
		Notes:        "other",
	})
	require.True(t, conflict)
	require.Equal(t, "notes_mismatch", reason)
}
// TestAssignSubscriptionGroupTypeValidation: assigning against a group that
// is not a subscription-type group fails with GROUP_NOT_SUBSCRIPTION_TYPE.
func TestAssignSubscriptionGroupTypeValidation(t *testing.T) {
	groupRepo := &subscriptionGroupRepoStub{
		group: &Group{ID: 1, SubscriptionType: SubscriptionTypeStandard},
	}
	subRepo := newSubscriptionUserSubRepoStub()
	svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
	_, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
		UserID:       1,
		GroupID:      1,
		ValidityDays: 30,
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrGroupNotSubscriptionType), infraerrors.Code(err))
}
// strconvFormatInt renders v as a base-10 decimal string.
func strconvFormatInt(v int64) string {
	const base = 10
	return strconv.FormatInt(v, base)
}
// infraerrorsReason extracts the machine-readable reason code from err.
func infraerrorsReason(err error) string {
	reason := infraerrors.Reason(err)
	return reason
}

View File

@@ -6,6 +6,7 @@ import (
"log" "log"
"math/rand/v2" "math/rand/v2"
"strconv" "strconv"
"strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -24,16 +25,17 @@ var MaxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
const MaxValidityDays = 36500 const MaxValidityDays = 36500
var ( var (
ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found") ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired") ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended") ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group") ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type") ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil") ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)") ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
) )
// SubscriptionService 订阅服务 // SubscriptionService 订阅服务
@@ -150,40 +152,10 @@ type AssignSubscriptionInput struct {
// AssignSubscription 分配订阅给用户(不允许重复分配) // AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) { func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
// 检查分组是否存在且为订阅类型 sub, _, err := s.assignSubscriptionWithReuse(ctx, input)
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, ErrGroupNotSubscriptionType
}
// 检查是否已存在订阅
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if exists {
return nil, ErrSubscriptionAlreadyExists
}
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, err
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return sub, nil return sub, nil
} }
@@ -363,9 +335,12 @@ type BulkAssignSubscriptionInput struct {
// BulkAssignResult 批量分配结果 // BulkAssignResult 批量分配结果
type BulkAssignResult struct { type BulkAssignResult struct {
SuccessCount int SuccessCount int
CreatedCount int
ReusedCount int
FailedCount int FailedCount int
Subscriptions []UserSubscription Subscriptions []UserSubscription
Errors []string Errors []string
Statuses map[int64]string
} }
// BulkAssignSubscription 批量分配订阅 // BulkAssignSubscription 批量分配订阅
@@ -373,10 +348,11 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
result := &BulkAssignResult{ result := &BulkAssignResult{
Subscriptions: make([]UserSubscription, 0), Subscriptions: make([]UserSubscription, 0),
Errors: make([]string, 0), Errors: make([]string, 0),
Statuses: make(map[int64]string),
} }
for _, userID := range input.UserIDs { for _, userID := range input.UserIDs {
sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{ sub, reused, err := s.assignSubscriptionWithReuse(ctx, &AssignSubscriptionInput{
UserID: userID, UserID: userID,
GroupID: input.GroupID, GroupID: input.GroupID,
ValidityDays: input.ValidityDays, ValidityDays: input.ValidityDays,
@@ -386,15 +362,105 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
if err != nil { if err != nil {
result.FailedCount++ result.FailedCount++
result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err)) result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
result.Statuses[userID] = "failed"
} else { } else {
result.SuccessCount++ result.SuccessCount++
result.Subscriptions = append(result.Subscriptions, *sub) result.Subscriptions = append(result.Subscriptions, *sub)
if reused {
result.ReusedCount++
result.Statuses[userID] = "reused"
} else {
result.CreatedCount++
result.Statuses[userID] = "created"
}
} }
} }
return result, nil return result, nil
} }
// assignSubscriptionWithReuse assigns group input.GroupID to input.UserID.
// It returns (subscription, reused, error): an existing assignment whose
// semantics match the request is returned as-is with reused=true; a semantic
// mismatch fails with ErrSubscriptionAssignConflict; otherwise a new
// subscription is created (reused=false) and the related caches invalidated.
func (s *SubscriptionService) assignSubscriptionWithReuse(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
	// Verify the group exists and is a subscription-type group.
	group, err := s.groupRepo.GetByID(ctx, input.GroupID)
	if err != nil {
		return nil, false, fmt.Errorf("group not found: %w", err)
	}
	if !group.IsSubscriptionType() {
		return nil, false, ErrGroupNotSubscriptionType
	}
	// Check for an existing subscription; when one exists, treat the request
	// as an idempotent success and return it (unless semantics conflict).
	exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
	if err != nil {
		return nil, false, err
	}
	if exists {
		sub, getErr := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
		if getErr != nil {
			return nil, false, getErr
		}
		if conflictReason, conflict := detectAssignSemanticConflict(sub, input); conflict {
			return nil, false, ErrSubscriptionAssignConflict.WithMetadata(map[string]string{
				"conflict_reason": conflictReason,
			})
		}
		return sub, true, nil
	}
	sub, err := s.createSubscription(ctx, input)
	if err != nil {
		return nil, false, err
	}
	// Invalidate the subscription caches.
	s.InvalidateSubCache(input.UserID, input.GroupID)
	if s.billingCacheService != nil {
		userID, groupID := input.UserID, input.GroupID
		// Billing-cache invalidation is best-effort: it runs asynchronously
		// under its own 5s timeout and its error is deliberately ignored.
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
		}()
	}
	return sub, false, nil
}
// detectAssignSemanticConflict compares an existing subscription with a new
// assignment request. It returns a machine-readable reason
// ("validity_days_mismatch" or "notes_mismatch") and true when their
// semantics diverge, or ("", false) when the existing record may be reused.
func detectAssignSemanticConflict(existing *UserSubscription, input *AssignSubscriptionInput) (string, bool) {
	if existing == nil || input == nil {
		return "", false
	}
	days := normalizeAssignValidityDays(input.ValidityDays)
	if !existing.StartsAt.IsZero() {
		// Recompute the expiry the request implies and compare it with the
		// stored one, honoring the global MaxExpiresAt cap.
		want := existing.StartsAt.AddDate(0, 0, days)
		if want.After(MaxExpiresAt) {
			want = MaxExpiresAt
		}
		if !existing.ExpiresAt.Equal(want) {
			return "validity_days_mismatch", true
		}
	}
	if strings.TrimSpace(existing.Notes) != strings.TrimSpace(input.Notes) {
		return "notes_mismatch", true
	}
	return "", false
}
// normalizeAssignValidityDays clamps a requested validity period: values of
// zero or below fall back to 30 days and values above MaxValidityDays are
// capped at that maximum.
func normalizeAssignValidityDays(days int) int {
	switch {
	case days <= 0:
		return 30
	case days > MaxValidityDays:
		return MaxValidityDays
	default:
		return days
	}
}
// RevokeSubscription 撤销订阅 // RevokeSubscription 撤销订阅
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error { func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
// 先获取订阅信息用于失效缓存 // 先获取订阅信息用于失效缓存

View File

@@ -0,0 +1,214 @@
package service
import (
"context"
"fmt"
"strconv"
"sync"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
	// systemOperationLockScope namespaces the global lock row in the idempotency store.
	systemOperationLockScope = "admin.system.operations.global_lock"
	// systemOperationLockKey is the single well-known key every instance contends on.
	systemOperationLockKey = "global-system-operation-lock"
)
var (
	// ErrSystemOperationBusy is returned when another system operation already holds the global lock.
	ErrSystemOperationBusy = infraerrors.Conflict("SYSTEM_OPERATION_BUSY", "another system operation is in progress")
)
// SystemOperationLock represents an acquired global system-operation lock.
// It must eventually be released via SystemOperationLockService.Release.
type SystemOperationLock struct {
	recordID    int64         // idempotency record row backing this lock
	operationID string        // identifier of the operation holding the lock
	stopOnce    sync.Once     // makes stopping the renew goroutine idempotent
	stopCh      chan struct{} // closed on Release to stop lease renewal
}
// OperationID returns the identifier of the operation holding the lock, or
// the empty string for a nil lock.
func (l *SystemOperationLock) OperationID() string {
	if l != nil {
		return l.operationID
	}
	return ""
}
// SystemOperationLockService coordinates a single global lock for mutually
// exclusive system operations, persisted through the idempotency repository.
type SystemOperationLockService struct {
	repo          IdempotencyRepository // persistent store backing the lock record
	lease         time.Duration         // how long a lease lasts before it must be renewed
	renewInterval time.Duration         // cadence of the background renew loop
	ttl           time.Duration         // retention of the lock record after completion
}
// NewSystemOperationLockService builds the lock service, deriving the lease
// duration, renew cadence and record TTL from cfg with safe fallbacks.
func NewSystemOperationLockService(repo IdempotencyRepository, cfg IdempotencyConfig) *SystemOperationLockService {
	svc := &SystemOperationLockService{repo: repo}

	// The lease defaults to 30s when the configured processing timeout is unset.
	svc.lease = cfg.ProcessingTimeout
	if svc.lease <= 0 {
		svc.lease = 30 * time.Second
	}

	// Renew at a third of the lease, but never more often than once per second.
	svc.renewInterval = svc.lease / 3
	if svc.renewInterval < time.Second {
		svc.renewInterval = time.Second
	}

	// Completed lock records are retained for an hour unless configured otherwise.
	svc.ttl = cfg.SystemOperationTTL
	if svc.ttl <= 0 {
		svc.ttl = time.Hour
	}
	return svc
}
// Acquire attempts to take the single global system-operation lock on behalf
// of operationID. Only one caller may hold it at a time; competing callers
// receive ErrSystemOperationBusy (with holder/retry metadata when known).
// On success, a background goroutine keeps renewing the lease until Release
// is called; the returned lock must eventually be Released.
func (s *SystemOperationLockService) Acquire(ctx context.Context, operationID string) (*SystemOperationLock, error) {
	if s == nil || s.repo == nil {
		return nil, ErrIdempotencyStoreUnavail
	}
	if operationID == "" {
		return nil, infraerrors.BadRequest("SYSTEM_OPERATION_ID_REQUIRED", "operation id is required")
	}
	now := time.Now()
	expiresAt := now.Add(s.ttl)
	lockedUntil := now.Add(s.lease)
	keyHash := HashIdempotencyKey(systemOperationLockKey)
	record := &IdempotencyRecord{
		Scope:              systemOperationLockScope,
		IdempotencyKeyHash: keyHash,
		RequestFingerprint: operationID,
		Status:             IdempotencyStatusProcessing,
		LockedUntil:        &lockedUntil,
		ExpiresAt:          expiresAt,
	}
	// First writer wins: CreateProcessing reports whether we inserted the row.
	owner, err := s.repo.CreateProcessing(ctx, record)
	if err != nil {
		return nil, ErrIdempotencyStoreUnavail.WithCause(err)
	}
	if !owner {
		// Someone else holds (or previously held) the lock; inspect the row.
		existing, getErr := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash)
		if getErr != nil {
			return nil, ErrIdempotencyStoreUnavail.WithCause(getErr)
		}
		if existing == nil {
			return nil, ErrIdempotencyStoreUnavail
		}
		// A live lease means the lock is genuinely busy.
		if existing.Status == IdempotencyStatusProcessing && existing.LockedUntil != nil && existing.LockedUntil.After(now) {
			return nil, s.busyError(existing.RequestFingerprint, existing.LockedUntil, now)
		}
		// The previous holder finished or its lease expired: try to take over
		// the row with a compare-and-swap style reclaim.
		reclaimed, reclaimErr := s.repo.TryReclaim(
			ctx,
			existing.ID,
			existing.Status,
			now,
			lockedUntil,
			expiresAt,
		)
		if reclaimErr != nil {
			return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
		}
		if !reclaimed {
			// Lost the reclaim race; report whoever holds the lock now.
			latest, _ := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash)
			if latest != nil {
				return nil, s.busyError(latest.RequestFingerprint, latest.LockedUntil, now)
			}
			return nil, ErrSystemOperationBusy
		}
		record.ID = existing.ID
	}
	if record.ID == 0 {
		return nil, ErrIdempotencyStoreUnavail
	}
	lock := &SystemOperationLock{
		recordID:    record.ID,
		operationID: operationID,
		stopCh:      make(chan struct{}),
	}
	// Keep the lease alive in the background until Release closes stopCh.
	go s.renewLoop(lock)
	return lock, nil
}
// Release stops the lease-renewal goroutine and finalizes the lock record.
// On success, the record is marked succeeded with a small JSON payload naming
// the releasing operation; on failure, it is marked retryable with the given
// reason (defaulting to "SYSTEM_OPERATION_FAILED"). Releasing a nil lock or
// calling Release repeatedly is safe.
func (s *SystemOperationLockService) Release(ctx context.Context, lock *SystemOperationLock, succeeded bool, failureReason string) error {
	if s == nil || s.repo == nil || lock == nil {
		return nil
	}
	// Stop renewals exactly once, even if Release is called multiple times.
	lock.stopOnce.Do(func() {
		close(lock.stopCh)
	})
	if ctx == nil {
		ctx = context.Background()
	}
	expiresAt := time.Now().Add(s.ttl)
	if succeeded {
		// Quote the operation ID so IDs containing quotes or backslashes
		// cannot corrupt the stored JSON payload. Operation IDs are expected
		// to be printable ASCII, for which strconv.Quote emits JSON-safe text.
		responseBody := fmt.Sprintf(`{"operation_id":%s,"released":true}`, strconv.Quote(lock.operationID))
		return s.repo.MarkSucceeded(ctx, lock.recordID, 200, responseBody, expiresAt)
	}
	reason := failureReason
	if reason == "" {
		reason = "SYSTEM_OPERATION_FAILED"
	}
	return s.repo.MarkFailedRetryable(ctx, lock.recordID, reason, time.Now(), expiresAt)
}
// renewLoop periodically extends the processing lease for lock until either
// Release closes stopCh or the repository reports that the lease is no longer
// owned by this operation (e.g. another instance reclaimed it).
func (s *SystemOperationLockService) renewLoop(lock *SystemOperationLock) {
	ticker := time.NewTicker(s.renewInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			now := time.Now()
			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
			ok, err := s.repo.ExtendProcessingLock(
				ctx,
				lock.recordID,
				lock.operationID,
				now.Add(s.lease),
				now.Add(s.ttl),
			)
			cancel()
			if err != nil {
				logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew failed operation_id=%s err=%v", lock.operationID, err)
				// A transient failure must not terminate the renew goroutine;
				// keep looping and retry the renewal on the next tick.
				continue
			}
			if !ok {
				// Ownership lost: stop renewing so we do not fight the new holder.
				logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew stopped operation_id=%s reason=ownership_lost", lock.operationID)
				return
			}
		case <-lock.stopCh:
			return
		}
	}
}
// busyError decorates ErrSystemOperationBusy with the current holder's
// operation id and a retry_after hint in seconds (minimum 1) when available.
func (s *SystemOperationLockService) busyError(operationID string, lockedUntil *time.Time, now time.Time) error {
	meta := map[string]string{}
	if operationID != "" {
		meta["operation_id"] = operationID
	}
	if lockedUntil != nil {
		retryAfter := int(lockedUntil.Sub(now).Seconds())
		if retryAfter <= 0 {
			retryAfter = 1
		}
		meta["retry_after"] = strconv.Itoa(retryAfter)
	}
	if len(meta) > 0 {
		return ErrSystemOperationBusy.WithMetadata(meta)
	}
	return ErrSystemOperationBusy
}

View File

@@ -0,0 +1,305 @@
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
// TestSystemOperationLockService_AcquireBusyAndRelease: while a lock is held,
// a second Acquire is rejected as busy with holder and retry_after metadata;
// after Release the lock can be acquired again.
func TestSystemOperationLockService_AcquireBusyAndRelease(t *testing.T) {
	repo := newInMemoryIdempotencyRepo()
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	lock1, err := svc.Acquire(context.Background(), "op-1")
	require.NoError(t, err)
	require.NotNil(t, lock1)
	_, err = svc.Acquire(context.Background(), "op-2")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
	appErr := infraerrors.FromError(err)
	require.Equal(t, "op-1", appErr.Metadata["operation_id"])
	require.NotEmpty(t, appErr.Metadata["retry_after"])
	require.NoError(t, svc.Release(context.Background(), lock1, true, ""))
	lock2, err := svc.Acquire(context.Background(), "op-2")
	require.NoError(t, err)
	require.NotNil(t, lock2)
	require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
}
// TestSystemOperationLockService_RenewLease: while the lock stays held, the
// background renew loop should push locked_until past its initial value.
func TestSystemOperationLockService_RenewLease(t *testing.T) {
	repo := newInMemoryIdempotencyRepo()
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 5 * time.Second,
		ProcessingTimeout:  1200 * time.Millisecond,
	})
	lock, err := svc.Acquire(context.Background(), "op-renew")
	require.NoError(t, err)
	require.NotNil(t, lock)
	defer func() {
		_ = svc.Release(context.Background(), lock, true, "")
	}()
	keyHash := HashIdempotencyKey(systemOperationLockKey)
	initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
	require.NotNil(t, initial)
	require.NotNil(t, initial.LockedUntil)
	initialLockedUntil := *initial.LockedUntil
	// The renew interval is clamped to 1s, so 1.5s allows at least one renewal.
	time.Sleep(1500 * time.Millisecond)
	updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
	require.NotNil(t, updated)
	require.NotNil(t, updated.LockedUntil)
	require.True(t, updated.LockedUntil.After(initialLockedUntil), "locked_until should be renewed while lock is held")
}
// flakySystemLockRenewRepo wraps the in-memory repo and fails the very first
// ExtendProcessingLock call to simulate a transient renewal error.
type flakySystemLockRenewRepo struct {
	*inMemoryIdempotencyRepo

	extendCalls int32
}

func (r *flakySystemLockRenewRepo) ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	if atomic.AddInt32(&r.extendCalls, 1) == 1 {
		return false, errors.New("transient extend failure")
	}
	return r.inMemoryIdempotencyRepo.ExtendProcessingLock(ctx, id, requestFingerprint, newLockedUntil, newExpiresAt)
}
// TestSystemOperationLockService_RenewLeaseContinuesAfterTransientFailure: a
// one-off ExtendProcessingLock error must not kill the renew goroutine; the
// following tick should retry and extend the lease successfully.
func TestSystemOperationLockService_RenewLeaseContinuesAfterTransientFailure(t *testing.T) {
	repo := &flakySystemLockRenewRepo{inMemoryIdempotencyRepo: newInMemoryIdempotencyRepo()}
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 5 * time.Second,
		ProcessingTimeout:  2400 * time.Millisecond,
	})
	lock, err := svc.Acquire(context.Background(), "op-renew-transient")
	require.NoError(t, err)
	require.NotNil(t, lock)
	defer func() {
		_ = svc.Release(context.Background(), lock, true, "")
	}()
	keyHash := HashIdempotencyKey(systemOperationLockKey)
	initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
	require.NotNil(t, initial)
	require.NotNil(t, initial.LockedUntil)
	initialLockedUntil := *initial.LockedUntil
	// After the first renewal fails, the next round should retry and succeed
	// in pushing the lock expiry forward.
	require.Eventually(t, func() bool {
		updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
		if updated == nil || updated.LockedUntil == nil {
			return false
		}
		return atomic.LoadInt32(&repo.extendCalls) >= 2 && updated.LockedUntil.After(initialLockedUntil)
	}, 4*time.Second, 100*time.Millisecond, "renew loop should continue after transient error")
}
// TestSystemOperationLockService_SameOperationIDRetryWhileRunning: retrying
// with the same operation ID while the lock is held is still rejected as
// busy — the lock is not re-entrant per operation ID.
func TestSystemOperationLockService_SameOperationIDRetryWhileRunning(t *testing.T) {
	repo := newInMemoryIdempotencyRepo()
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	lock1, err := svc.Acquire(context.Background(), "op-same")
	require.NoError(t, err)
	require.NotNil(t, lock1)
	_, err = svc.Acquire(context.Background(), "op-same")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
	appErr := infraerrors.FromError(err)
	require.Equal(t, "op-same", appErr.Metadata["operation_id"])
	require.NoError(t, svc.Release(context.Background(), lock1, true, ""))
	lock2, err := svc.Acquire(context.Background(), "op-same")
	require.NoError(t, err)
	require.NotNil(t, lock2)
	require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
}
// TestSystemOperationLockService_RecoverAfterLeaseExpired: if a holder stops
// renewing without releasing (simulated crash), a later Acquire can reclaim
// the lock once the lease has expired.
func TestSystemOperationLockService_RecoverAfterLeaseExpired(t *testing.T) {
	repo := newInMemoryIdempotencyRepo()
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 5 * time.Second,
		ProcessingTimeout:  300 * time.Millisecond,
	})
	lock1, err := svc.Acquire(context.Background(), "op-crashed")
	require.NoError(t, err)
	require.NotNil(t, lock1)
	// Simulate an instance crash: stop renewing without calling Release.
	lock1.stopOnce.Do(func() {
		close(lock1.stopCh)
	})
	time.Sleep(450 * time.Millisecond)
	lock2, err := svc.Acquire(context.Background(), "op-recovered")
	require.NoError(t, err, "expired lease should allow a new operation to reclaim lock")
	require.NotNil(t, lock2)
	require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
}
// systemLockRepoStub is a scriptable idempotency repository: each field
// drives the result of the corresponding method, letting tests target
// individual error branches of the lock service.
type systemLockRepoStub struct {
	createOwner bool
	createErr   error
	existing    *IdempotencyRecord
	getErr      error
	reclaimOK   bool
	reclaimErr  error
	markSuccErr error
	markFailErr error
}

func (s *systemLockRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	if s.createErr == nil {
		return s.createOwner, nil
	}
	return false, s.createErr
}

func (s *systemLockRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	if s.getErr == nil {
		return cloneRecord(s.existing), nil
	}
	return nil, s.getErr
}

func (s *systemLockRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	if s.reclaimErr == nil {
		return s.reclaimOK, nil
	}
	return false, s.reclaimErr
}

func (s *systemLockRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return true, nil
}

func (s *systemLockRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return s.markSuccErr
}

func (s *systemLockRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return s.markFailErr
}

func (s *systemLockRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
	return 0, nil
}
// TestSystemOperationLockService_InputAndStoreErrorBranches covers the guard
// branches of Acquire: nil service, nil repo, empty operation id, and a store
// error surfaced from CreateProcessing.
func TestSystemOperationLockService_InputAndStoreErrorBranches(t *testing.T) {
	var nilSvc *SystemOperationLockService
	_, err := nilSvc.Acquire(context.Background(), "x")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	svc := &SystemOperationLockService{repo: nil}
	_, err = svc.Acquire(context.Background(), "x")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	svc = NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	_, err = svc.Acquire(context.Background(), "")
	require.Error(t, err)
	require.Equal(t, "SYSTEM_OPERATION_ID_REQUIRED", infraerrors.Reason(err))
	badStore := &systemLockRepoStub{createErr: errors.New("db down")}
	svc = NewSystemOperationLockService(badStore, IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	_, err = svc.Acquire(context.Background(), "x")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
}
// TestSystemOperationLockService_ExistingNilAndReclaimBranches drives Acquire
// through the lost-create branches: the stub reports the caller did not
// create the record (createOwner=false), and the follow-up lookup/reclaim
// attempts fail in different ways.
func TestSystemOperationLockService_ExistingNilAndReclaimBranches(t *testing.T) {
	now := time.Now()
	repo := &systemLockRepoStub{
		createOwner: false, // Acquire does not own the record it tried to create
	}
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	// Not the owner and no existing record can be returned: store-unavailable.
	_, err := svc.Acquire(context.Background(), "op")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Plant a failed-retryable record held by another operation whose lock
	// has already expired, so Acquire should attempt a reclaim.
	repo.existing = &IdempotencyRecord{
		ID:                 1,
		Scope:              systemOperationLockScope,
		IdempotencyKeyHash: HashIdempotencyKey(systemOperationLockKey),
		RequestFingerprint: "other-op",
		Status:             IdempotencyStatusFailedRetryable,
		LockedUntil:        ptrTime(now.Add(-time.Second)), // lock already expired
		ExpiresAt:          now.Add(time.Hour),
	}
	// Reclaim hitting a storage error surfaces as store-unavailable.
	repo.reclaimErr = errors.New("reclaim failed")
	_, err = svc.Acquire(context.Background(), "op")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Reclaim cleanly refused (record still held elsewhere): busy error.
	repo.reclaimErr = nil
	repo.reclaimOK = false
	_, err = svc.Acquire(context.Background(), "op")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
}
// TestSystemOperationLockService_ReleaseBranchesAndOperationID covers
// OperationID on a nil lock, a normal acquire followed by two Release calls
// (failure then success), Release when the store's mark writes error out,
// the nil-service Release no-op, and the busyError fallback with no record.
func TestSystemOperationLockService_ReleaseBranchesAndOperationID(t *testing.T) {
	// A nil lock reports an empty operation ID instead of panicking.
	require.Equal(t, "", (*SystemOperationLock)(nil).OperationID())
	svc := NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	lock, err := svc.Acquire(context.Background(), "op")
	require.NoError(t, err)
	require.NotNil(t, lock)
	// Releasing as failure and then again as success both succeed.
	require.NoError(t, svc.Release(context.Background(), lock, false, ""))
	require.NoError(t, svc.Release(context.Background(), lock, true, ""))
	// A store whose mark-succeeded / mark-failed writes error out should
	// surface those errors from Release.
	repo := &systemLockRepoStub{
		createOwner: true,
		markSuccErr: errors.New("mark succeeded failed"),
		markFailErr: errors.New("mark failed failed"),
	}
	svc = NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	lock = &SystemOperationLock{recordID: 1, operationID: "op2", stopCh: make(chan struct{})}
	require.Error(t, svc.Release(context.Background(), lock, true, ""))
	lock = &SystemOperationLock{recordID: 1, operationID: "op3", stopCh: make(chan struct{})}
	require.Error(t, svc.Release(context.Background(), lock, false, "BAD"))
	// Release on a nil service with a nil lock is a harmless no-op.
	var nilLockSvc *SystemOperationLockService
	require.NoError(t, nilLockSvc.Release(context.Background(), nil, true, ""))
	// busyError with no record still yields the busy error code.
	err = svc.busyError("", nil, time.Now())
	require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
}

View File

@@ -320,6 +320,10 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc
return err
}
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status) logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
if status == UsageCleanupStatusCanceled {
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent hit: task=%d operator=%d", taskID, canceledBy)
return nil
}
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning { if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status") return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
} }
@@ -329,6 +333,11 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc
} }
if !ok { if !ok {
// 状态可能并发改变 // 状态可能并发改变
currentStatus, getErr := s.repo.GetTaskStatus(ctx, taskID)
if getErr == nil && currentStatus == UsageCleanupStatusCanceled {
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent race hit: task=%d operator=%d", taskID, canceledBy)
return nil
}
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status") return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
} }
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy) logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)

View File

@@ -644,6 +644,23 @@ func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) {
require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
}
// TestUsageCleanupServiceCancelTaskAlreadyCanceledIsIdempotent verifies that
// canceling a task already in the canceled state succeeds without issuing
// another cancel write to the repository.
func TestUsageCleanupServiceCancelTaskAlreadyCanceledIsIdempotent(t *testing.T) {
	const taskID = int64(7)
	stub := &cleanupRepoStub{
		statusByID: map[int64]string{taskID: UsageCleanupStatusCanceled},
	}
	conf := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
	service := NewUsageCleanupService(stub, nil, nil, conf)

	require.NoError(t, service.CancelTask(context.Background(), taskID, 1))

	stub.mu.Lock()
	defer stub.mu.Unlock()
	require.Empty(t, stub.cancelCalls, "already canceled should return success without extra cancel write")
}
func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) { func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) {
shouldCancel := false shouldCancel := false
repo := &cleanupRepoStub{ repo := &cleanupRepoStub{

View File

@@ -225,6 +225,45 @@ func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Confi
return svc return svc
} }
// buildIdempotencyConfig materializes an IdempotencyConfig from the global
// configuration, keeping the package defaults for any unset (<= 0) value.
func buildIdempotencyConfig(cfg *config.Config) IdempotencyConfig {
	out := DefaultIdempotencyConfig()
	if cfg == nil {
		return out
	}
	src := cfg.Idempotency
	if src.DefaultTTLSeconds > 0 {
		out.DefaultTTL = time.Duration(src.DefaultTTLSeconds) * time.Second
	}
	if src.SystemOperationTTLSeconds > 0 {
		out.SystemOperationTTL = time.Duration(src.SystemOperationTTLSeconds) * time.Second
	}
	if src.ProcessingTimeoutSeconds > 0 {
		out.ProcessingTimeout = time.Duration(src.ProcessingTimeoutSeconds) * time.Second
	}
	if src.FailedRetryBackoffSeconds > 0 {
		out.FailedRetryBackoff = time.Duration(src.FailedRetryBackoffSeconds) * time.Second
	}
	if src.MaxStoredResponseLen > 0 {
		out.MaxStoredResponseLen = src.MaxStoredResponseLen
	}
	// ObserveOnly is a plain flag and is always taken from the config.
	out.ObserveOnly = src.ObserveOnly
	return out
}
// ProvideIdempotencyCoordinator wires the idempotency coordinator from the
// repository and config, and registers it as the process-wide default.
func ProvideIdempotencyCoordinator(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCoordinator {
	coordinator := NewIdempotencyCoordinator(repo, buildIdempotencyConfig(cfg))
	SetDefaultIdempotencyCoordinator(coordinator)
	return coordinator
}
// ProvideSystemOperationLockService builds the lock service for system
// operations on top of the idempotency repository and config.
func ProvideSystemOperationLockService(repo IdempotencyRepository, cfg *config.Config) *SystemOperationLockService {
	return NewSystemOperationLockService(repo, buildIdempotencyConfig(cfg))
}
// ProvideIdempotencyCleanupService creates the background cleaner for expired
// idempotency records and starts it immediately.
func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService {
	svc := NewIdempotencyCleanupService(repo, cfg)
	svc.Start()
	return svc
}
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService. // ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
func ProvideOpsScheduledReportService( func ProvideOpsScheduledReportService(
opsService *OpsService, opsService *OpsService,
@@ -318,4 +357,7 @@ var ProviderSet = wire.NewSet(
NewTotpService, NewTotpService,
NewErrorPassthroughService, NewErrorPassthroughService,
NewDigestSessionStore, NewDigestSessionStore,
ProvideIdempotencyCoordinator,
ProvideSystemOperationLockService,
ProvideIdempotencyCleanupService,
) )

View File

@@ -0,0 +1,27 @@
-- Idempotency records: request de-duplication and response replay for
-- critical write endpoints.
-- Idempotent migration: safe to re-run (all statements use IF NOT EXISTS).
CREATE TABLE IF NOT EXISTS idempotency_records (
    id BIGSERIAL PRIMARY KEY,
    scope VARCHAR(128) NOT NULL,
    idempotency_key_hash VARCHAR(64) NOT NULL,
    request_fingerprint VARCHAR(64) NOT NULL,
    status VARCHAR(32) NOT NULL,
    -- response_status/response_body hold the stored result for replay
    response_status INTEGER,
    response_body TEXT,
    error_reason VARCHAR(128),
    locked_until TIMESTAMPTZ,
    expires_at TIMESTAMPTZ NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- At most one record per (scope, key hash): the uniqueness that makes
-- concurrent duplicate requests safe.
CREATE UNIQUE INDEX IF NOT EXISTS idx_idempotency_records_scope_key
ON idempotency_records (scope, idempotency_key_hash);
-- Supports periodic cleanup of expired records.
CREATE INDEX IF NOT EXISTS idx_idempotency_records_expires_at
ON idempotency_records (expires_at);
-- Supports lookups by status and lock deadline (e.g. reclaiming records
-- whose processing lock has timed out).
CREATE INDEX IF NOT EXISTS idx_idempotency_records_status_locked_until
ON idempotency_records (status, locked_until);

View File

@@ -568,6 +568,30 @@ usage_cleanup:
# 单次任务最大执行时长(秒) # 单次任务最大执行时长(秒)
task_timeout_seconds: 1800 task_timeout_seconds: 1800
# =============================================================================
# Idempotency configuration for HTTP write endpoints
# Idempotency Configuration
# =============================================================================
idempotency:
  # Observe-only mode:
  #   true:  observation phase — requests without an Idempotency-Key are still
  #          allowed through (but recorded)
  #   false: enforcement phase — requests without an Idempotency-Key are
  #          rejected (only for endpoints wired into idempotency protection)
  observe_only: true
  # TTL for idempotency records of critical write endpoints (seconds)
  default_ttl_seconds: 86400
  # TTL for system-operation (update/rollback/restart) idempotency records (seconds)
  system_operation_ttl_seconds: 3600
  # Processing-lock timeout (seconds)
  processing_timeout_seconds: 30
  # Backoff window before a retryable failure may be retried (seconds)
  failed_retry_backoff_seconds: 5
  # Maximum persisted response body length (bytes)
  max_stored_response_len: 65536
  # Interval between sweeps of expired idempotency records (seconds)
  cleanup_interval_seconds: 60
  # Maximum number of rows deleted per cleanup round
  cleanup_batch_size: 500
# ============================================================================= # =============================================================================
# Concurrency Wait Configuration # Concurrency Wait Configuration
# 并发等待配置 # 并发等待配置