diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index c0199258..1ba6b184 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -74,6 +74,7 @@ func provideCleanup( accountExpiry *service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, + idempotencyCleanup *service.IdempotencyCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, @@ -147,6 +148,12 @@ func provideCleanup( } return nil }}, + {"IdempotencyCleanupService", func() error { + if idempotencyCleanup != nil { + idempotencyCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 0b57334b..7a277112 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -168,7 +168,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) serviceBuildInfo := provideServiceBuildInfo(buildInfo) updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) - systemHandler := handler.ProvideSystemHandler(updateService) + idempotencyRepository := repository.NewIdempotencyRepository(client, db) + systemOperationLockService := service.ProvideSystemOperationLockService(idempotencyRepository, configConfig) + systemHandler := handler.ProvideSystemHandler(updateService, systemOperationLockService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) usageCleanupRepository := repository.NewUsageCleanupRepository(client, db) usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig) @@ -191,7 +193,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler) + idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) + idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -206,7 +210,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -243,6 +247,7 @@ func provideCleanup( accountExpiry *service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, + idempotencyCleanup *service.IdempotencyCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, @@ -315,6 +320,12 @@ func provideCleanup( } return nil }}, + {"IdempotencyCleanupService", func() error { + if idempotencyCleanup != nil { + idempotencyCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil diff --git a/backend/ent/schema/idempotency_record.go b/backend/ent/schema/idempotency_record.go new file mode 100644 index 00000000..ed09ad65 --- /dev/null +++ b/backend/ent/schema/idempotency_record.go @@ -0,0 +1,50 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// IdempotencyRecord 幂等请求记录表。 +type IdempotencyRecord struct { + ent.Schema +} + +func (IdempotencyRecord) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "idempotency_records"}, + } +} + +func (IdempotencyRecord) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (IdempotencyRecord) Fields() []ent.Field { + return []ent.Field{ + field.String("scope").MaxLen(128), + field.String("idempotency_key_hash").MaxLen(64), + field.String("request_fingerprint").MaxLen(64), + field.String("status").MaxLen(32), + field.Int("response_status").Optional().Nillable(), + field.String("response_body").Optional().Nillable(), + field.String("error_reason").MaxLen(128).Optional().Nillable(), + field.Time("locked_until").Optional().Nillable(), + field.Time("expires_at"), + } +} + +func (IdempotencyRecord) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("scope", "idempotency_key_hash").Unique(), + index.Fields("expires_at"), + index.Fields("status", "locked_until"), + } +} diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3ccf6b3b..c4d4fdab 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -74,6 +74,7 @@ type Config struct { Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` Update UpdateConfig `mapstructure:"update"` + Idempotency IdempotencyConfig `mapstructure:"idempotency"` } type LogConfig struct { @@ -137,6 +138,25 @@ type UpdateConfig struct { ProxyURL string `mapstructure:"proxy_url"` } +type IdempotencyConfig struct { + // ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。 + ObserveOnly bool `mapstructure:"observe_only"` + // DefaultTTLSeconds 关键写接口的幂等记录默认 TTL(秒)。 + DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"` + // SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL(秒)。 + SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"` + // ProcessingTimeoutSeconds processing 状态锁超时(秒)。 + ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"` + // FailedRetryBackoffSeconds 失败退避窗口(秒)。 + FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"` + // MaxStoredResponseLen 持久化响应体最大长度(字节)。 + MaxStoredResponseLen int `mapstructure:"max_stored_response_len"` + // CleanupIntervalSeconds 过期记录清理周期(秒)。 + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` + // CleanupBatchSize 每次清理的最大记录数。 + CleanupBatchSize int `mapstructure:"cleanup_batch_size"` +} + type LinuxDoConnectConfig struct { Enabled bool `mapstructure:"enabled"` ClientID string `mapstructure:"client_id"` @@ -1117,6 +1137,16 @@ func setDefaults() { viper.SetDefault("usage_cleanup.worker_interval_seconds", 10) viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800) + // Idempotency + viper.SetDefault("idempotency.observe_only", true) + viper.SetDefault("idempotency.default_ttl_seconds", 86400) + viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600) + viper.SetDefault("idempotency.processing_timeout_seconds", 30) + viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5) + viper.SetDefault("idempotency.max_stored_response_len", 64*1024) + viper.SetDefault("idempotency.cleanup_interval_seconds", 60) + viper.SetDefault("idempotency.cleanup_batch_size", 500) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", true) @@ -1560,6 +1590,27 @@ func (c *Config) Validate() error { return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative") } } + if c.Idempotency.DefaultTTLSeconds <= 0 { + return fmt.Errorf("idempotency.default_ttl_seconds must be positive") + } + if c.Idempotency.SystemOperationTTLSeconds <= 0 { + return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive") + } + if c.Idempotency.ProcessingTimeoutSeconds <= 0 { + return fmt.Errorf("idempotency.processing_timeout_seconds must be positive") + } + if c.Idempotency.FailedRetryBackoffSeconds <= 0 { + return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive") + } + if c.Idempotency.MaxStoredResponseLen <= 0 { + return fmt.Errorf("idempotency.max_stored_response_len must be positive") + } + if c.Idempotency.CleanupIntervalSeconds <= 0 { + return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive") + } + if c.Idempotency.CleanupBatchSize <= 0 { + return fmt.Errorf("idempotency.cleanup_batch_size must be positive") + } if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 88aa62fa..b0402a3b 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -75,6 +75,42 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { } } +func TestLoadDefaultIdempotencyConfig(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Idempotency.ObserveOnly { + t.Fatalf("Idempotency.ObserveOnly = false, want true") + } + if cfg.Idempotency.DefaultTTLSeconds != 86400 { + t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", cfg.Idempotency.DefaultTTLSeconds) + } + if cfg.Idempotency.SystemOperationTTLSeconds != 3600 { + t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", cfg.Idempotency.SystemOperationTTLSeconds) + } +} + +func TestLoadIdempotencyConfigFromEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false") + t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + if cfg.Idempotency.ObserveOnly { + t.Fatalf("Idempotency.ObserveOnly = true, want false") + } + if cfg.Idempotency.DefaultTTLSeconds != 600 { + t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", cfg.Idempotency.DefaultTTLSeconds) + } +} + func TestLoadSchedulingConfigFromEnv(t *testing.T) { resetViperWithJWTSecret(t) t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 34397696..4ce17219 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) { return } - dataPayload := req.Data - if err := validateDataHeader(dataPayload); err != nil { + if err := validateDataHeader(req.Data); err != nil { response.BadRequest(c, err.Error()) return } + executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + return h.importData(ctx, req) + }) +} + +func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) { skipDefaultGroupBind := true if req.SkipDefaultGroupBind != nil { skipDefaultGroupBind = *req.SkipDefaultGroupBind } + dataPayload := req.Data result := DataImportResult{} - existingProxies, err := h.listAllProxies(c.Request.Context()) + + existingProxies, err := h.listAllProxies(ctx) if err != nil { - response.ErrorFrom(c, err) - return + return result, err } proxyKeyToID := make(map[string]int64, len(existingProxies)) @@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) { proxyKeyToID[key] = existingID result.ProxyReused++ if normalizedStatus != "" { - if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus { - _, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{ + if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus { + _, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{ Status: normalizedStatus, }) } @@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { continue } - created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ + created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{ Name: defaultProxyName(item.Name), Protocol: item.Protocol, Host: item.Host, @@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) { Username: item.Username, Password: item.Password, }) - if err != nil { + if createErr != nil { result.ProxyFailed++ result.Errors = append(result.Errors, DataImportError{ Kind: "proxy", Name: item.Name, ProxyKey: key, - Message: err.Error(), + Message: createErr.Error(), }) continue } @@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { result.ProxyCreated++ if normalizedStatus != "" && normalizedStatus != created.Status { - _, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{ + _, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{ Status: normalizedStatus, }) } @@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { SkipDefaultGroupBind: skipDefaultGroupBind, } - if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil { + if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil { result.AccountFailed++ result.Errors = append(result.Errors, DataImportError{ Kind: "account", @@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { result.AccountCreated++ } - response.Success(c, result) + return result, nil } func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) { diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 1aa0cf2b..a2a8dd43 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -405,21 +405,27 @@ func (h *AccountHandler) Create(c *gin.Context) { // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk - account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ - Name: req.Name, - Notes: req.Notes, - Platform: req.Platform, - Type: req.Type, - Credentials: req.Credentials, - Extra: req.Extra, - ProxyID: req.ProxyID, - Concurrency: req.Concurrency, - Priority: req.Priority, - RateMultiplier: req.RateMultiplier, - GroupIDs: req.GroupIDs, - ExpiresAt: req.ExpiresAt, - AutoPauseOnExpired: req.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, + result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: req.Name, + Notes: req.Notes, + Platform: req.Platform, + Type: req.Type, + Credentials: req.Credentials, + Extra: req.Extra, + ProxyID: req.ProxyID, + Concurrency: req.Concurrency, + Priority: req.Priority, + RateMultiplier: req.RateMultiplier, + GroupIDs: req.GroupIDs, + ExpiresAt: req.ExpiresAt, + AutoPauseOnExpired: req.AutoPauseOnExpired, + SkipMixedChannelCheck: skipCheck, + }) + if execErr != nil { + return nil, execErr + } + return h.buildAccountResponseWithRuntime(ctx, account), nil }) if err != nil { // 检查是否为混合渠道错误 @@ -440,11 +446,17 @@ func (h *AccountHandler) Create(c *gin.Context) { return } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } response.ErrorFrom(c, err) return } - response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) } // Update handles updating an account @@ -838,61 +850,62 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { return } - ctx := c.Request.Context() - success := 0 - failed := 0 - results := make([]gin.H, 0, len(req.Accounts)) + executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + success := 0 + failed := 0 + results := make([]gin.H, 0, len(req.Accounts)) - for _, item := range req.Accounts { - if item.RateMultiplier != nil && *item.RateMultiplier < 0 { - failed++ + for _, item := range req.Accounts { + if item.RateMultiplier != nil && *item.RateMultiplier < 0 { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": "rate_multiplier must be >= 0", + }) + continue + } + + skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk + + account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: item.Name, + Notes: item.Notes, + Platform: item.Platform, + Type: item.Type, + Credentials: item.Credentials, + Extra: item.Extra, + ProxyID: item.ProxyID, + Concurrency: item.Concurrency, + Priority: item.Priority, + RateMultiplier: item.RateMultiplier, + GroupIDs: item.GroupIDs, + ExpiresAt: item.ExpiresAt, + AutoPauseOnExpired: item.AutoPauseOnExpired, + SkipMixedChannelCheck: skipCheck, + }) + if err != nil { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": err.Error(), + }) + continue + } + success++ results = append(results, gin.H{ "name": item.Name, - "success": false, - "error": "rate_multiplier must be >= 0", + "id": account.ID, + "success": true, }) - continue } - skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk - - account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ - Name: item.Name, - Notes: item.Notes, - Platform: item.Platform, - Type: item.Type, - Credentials: item.Credentials, - Extra: item.Extra, - ProxyID: item.ProxyID, - Concurrency: item.Concurrency, - Priority: item.Priority, - RateMultiplier: item.RateMultiplier, - GroupIDs: item.GroupIDs, - ExpiresAt: item.ExpiresAt, - AutoPauseOnExpired: item.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, - }) - if err != nil { - failed++ - results = append(results, gin.H{ - "name": item.Name, - "success": false, - "error": err.Error(), - }) - continue - } - success++ - results = append(results, gin.H{ - "name": item.Name, - "id": account.ID, - "success": true, - }) - } - - response.Success(c, gin.H{ - "success": success, - "failed": failed, - "results": results, + return gin.H{ + "success": success, + "failed": failed, + "results": results, + }, nil }) } diff --git a/backend/internal/handler/admin/idempotency_helper.go b/backend/internal/handler/admin/idempotency_helper.go new file mode 100644 index 00000000..aa8eeaaf --- /dev/null +++ b/backend/internal/handler/admin/idempotency_helper.go @@ -0,0 +1,115 @@ +package admin + +import ( + "context" + "strconv" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +type idempotencyStoreUnavailableMode int + +const ( + idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota + idempotencyStoreUnavailableFailOpen +) + +func executeAdminIdempotent( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) (*service.IdempotencyExecuteResult, error) { + coordinator := service.DefaultIdempotencyCoordinator() + if coordinator == nil { + data, err := execute(c.Request.Context()) + if err != nil { + return nil, err + } + return &service.IdempotencyExecuteResult{Data: data}, nil + } + + actorScope := "admin:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + + return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{ + Scope: scope, + ActorScope: actorScope, + Method: c.Request.Method, + Route: c.FullPath(), + IdempotencyKey: c.GetHeader("Idempotency-Key"), + Payload: payload, + RequireKey: true, + TTL: ttl, + }, execute) +} + +func executeAdminIdempotentJSON( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute) +} + +func executeAdminIdempotentJSONFailOpenOnStoreUnavailable( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute) +} + +func executeAdminIdempotentJSONWithMode( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + mode idempotencyStoreUnavailableMode, + execute func(context.Context) (any, error), +) { + result, err := executeAdminIdempotent(c, scope, payload, ttl, execute) + if err != nil { + if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) { + strategy := "fail_close" + if mode == idempotencyStoreUnavailableFailOpen { + strategy = "fail_open" + } + service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy) + logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy) + if mode == idempotencyStoreUnavailableFailOpen { + data, fallbackErr := execute(c.Request.Context()) + if fallbackErr != nil { + response.ErrorFrom(c, fallbackErr) + return + } + c.Header("X-Idempotency-Degraded", "store-unavailable") + response.Success(c, data) + return + } + } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + response.ErrorFrom(c, err) + return + } + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) +} diff --git a/backend/internal/handler/admin/idempotency_helper_test.go b/backend/internal/handler/admin/idempotency_helper_test.go new file mode 100644 index 00000000..7dd86e16 --- /dev/null +++ b/backend/internal/handler/admin/idempotency_helper_test.go @@ -0,0 +1,285 @@ +package admin + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type storeUnavailableRepoStub struct{} + +func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "test-key-1") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable") +} + +func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "test-key-2") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded")) + require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue") +} + +type memoryIdempotencyRepoStub struct { + mu sync.Mutex + nextID int64 + data map[string]*service.IdempotencyRecord +} + +func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub { + return &memoryIdempotencyRepoStub{ + nextID: 1, + data: make(map[string]*service.IdempotencyRecord), + } +} + +func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string { + return scope + "|" + keyHash +} + +func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + return &out +} + +func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + cp := r.clone(record) + cp.ID = r.nextID + r.nextID++ + r.data[k] = cp + record.ID = cp.ID + return true, nil +} + +func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.clone(r.data[r.key(scope, keyHash)]), nil +} + +func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = service.IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + return true, nil + } + return false, nil +} + +func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + return true, nil + } + return false, nil +} + +func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + rec.ErrorReason = nil + return nil + } + return nil +} + +func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.ErrorReason = &errorReason + return nil + } + return nil +} + +func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) { + return 0, nil +} + +func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newMemoryIdempotencyRepoStub() + cfg := service.DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg)) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed atomic.Int32 + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed.Add(1) + time.Sleep(120 * time.Millisecond) + return gin.H{"ok": true}, nil + }) + }) + + call := func() (int, http.Header) { + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "same-key") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec.Code, rec.Header() + } + + var status1, status2 int + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + status1, _ = call() + }() + go func() { + defer wg.Done() + status2, _ = call() + }() + wg.Wait() + + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1) + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2) + require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once") + + status3, headers3 := call() + require.Equal(t, http.StatusOK, status3) + require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed")) + require.Equal(t, int32(1), executed.Load()) +} diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index 5a9cd7a0..9fd187fc 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "strings" @@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) { return } - proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ - Name: strings.TrimSpace(req.Name), - Protocol: strings.TrimSpace(req.Protocol), - Host: strings.TrimSpace(req.Host), - Port: req.Port, - Username: strings.TrimSpace(req.Username), - Password: strings.TrimSpace(req.Password), + executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{ + Name: strings.TrimSpace(req.Name), + Protocol: strings.TrimSpace(req.Protocol), + Host: strings.TrimSpace(req.Host), + Port: req.Port, + Username: strings.TrimSpace(req.Username), + Password: strings.TrimSpace(req.Password), + }) + if err != nil { + return nil, err + } + return dto.ProxyFromService(proxy), nil }) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, dto.ProxyFromService(proxy)) } // Update handles updating a proxy diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 02752fea..7073061d 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -2,6 +2,7 @@ package admin import ( "bytes" + "context" "encoding/csv" "fmt" "strconv" @@ -88,23 +89,24 @@ func (h *RedeemHandler) Generate(c *gin.Context) { return } - codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{ - Count: req.Count, - Type: req.Type, - Value: req.Value, - GroupID: req.GroupID, - ValidityDays: req.ValidityDays, - }) - if err != nil { - response.ErrorFrom(c, err) - return - } + executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{ + Count: req.Count, + Type: req.Type, + Value: req.Value, + GroupID: req.GroupID, + ValidityDays: req.ValidityDays, + }) + if execErr != nil { + return nil, execErr + } - out := make([]dto.AdminRedeemCode, 0, len(codes)) - for i := range codes { - out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) - } - response.Success(c, out) + out := make([]dto.AdminRedeemCode, 0, len(codes)) + for i := range codes { + out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) + } + return out, nil + }) } // Delete handles deleting a redeem code diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index 51995ab1..e5b6db13 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { return } - subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) - if err != nil { - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + SubscriptionID int64 `json:"subscription_id"` + Body AdjustSubscriptionRequest `json:"body"` + }{ + SubscriptionID: subscriptionID, + Body: req, } - - response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) + executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days) + if execErr != nil { + return nil, execErr + } + return dto.UserSubscriptionFromServiceAdmin(subscription), nil + }) } // Revoke handles revoking a subscription diff --git a/backend/internal/handler/admin/system_handler.go b/backend/internal/handler/admin/system_handler.go index 28c075aa..3e2022c7 100644 --- a/backend/internal/handler/admin/system_handler.go +++ b/backend/internal/handler/admin/system_handler.go @@ -1,11 +1,15 @@ package admin import ( + "context" "net/http" + "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/sysutil" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -14,12 +18,14 @@ import ( // SystemHandler handles system-related operations type SystemHandler struct { updateSvc *service.UpdateService + lockSvc *service.SystemOperationLockService } // NewSystemHandler creates a new SystemHandler -func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler { +func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler { return &SystemHandler{ updateSvc: updateSvc, + lockSvc: lockSvc, } } @@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) { // PerformUpdate downloads and applies the update // POST /api/v1/admin/system/update func (h *SystemHandler) PerformUpdate(c *gin.Context) { - if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil { - response.Error(c, http.StatusInternalServerError, err.Error()) - return - } - response.Success(c, gin.H{ - "message": "Update completed. Please restart the service.", - "need_restart": true, + operationID := buildSystemOperationID(c, "update") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + var releaseReason string + succeeded := false + defer func() { + release(releaseReason, succeeded) + }() + + if err := h.updateSvc.PerformUpdate(ctx); err != nil { + releaseReason = "SYSTEM_UPDATE_FAILED" + return nil, err + } + succeeded = true + + return gin.H{ + "message": "Update completed. Please restart the service.", + "need_restart": true, + "operation_id": lock.OperationID(), + }, nil }) } // Rollback restores the previous version // POST /api/v1/admin/system/rollback func (h *SystemHandler) Rollback(c *gin.Context) { - if err := h.updateSvc.Rollback(); err != nil { - response.Error(c, http.StatusInternalServerError, err.Error()) - return - } - response.Success(c, gin.H{ - "message": "Rollback completed. Please restart the service.", - "need_restart": true, + operationID := buildSystemOperationID(c, "rollback") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + var releaseReason string + succeeded := false + defer func() { + release(releaseReason, succeeded) + }() + + if err := h.updateSvc.Rollback(); err != nil { + releaseReason = "SYSTEM_ROLLBACK_FAILED" + return nil, err + } + succeeded = true + + return gin.H{ + "message": "Rollback completed. Please restart the service.", + "need_restart": true, + "operation_id": lock.OperationID(), + }, nil }) } // RestartService restarts the systemd service // POST /api/v1/admin/system/restart func (h *SystemHandler) RestartService(c *gin.Context) { - // Schedule service restart in background after sending response - // This ensures the client receives the success response before the service restarts - go func() { - // Wait a moment to ensure the response is sent - time.Sleep(500 * time.Millisecond) - sysutil.RestartServiceAsync() - }() + operationID := buildSystemOperationID(c, "restart") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + succeeded := false + defer func() { + release("", succeeded) + }() - response.Success(c, gin.H{ - "message": "Service restart initiated", + // Schedule service restart in background after sending response + // This ensures the client receives the success response before the service restarts + go func() { + // Wait a moment to ensure the response is sent + time.Sleep(500 * time.Millisecond) + sysutil.RestartServiceAsync() + }() + succeeded = true + return gin.H{ + "message": "Service restart initiated", + "operation_id": lock.OperationID(), + }, nil }) } + +func (h *SystemHandler) acquireSystemLock( + ctx context.Context, + operationID string, +) (*service.SystemOperationLock, func(string, bool), error) { + if h.lockSvc == nil { + return nil, nil, service.ErrIdempotencyStoreUnavail + } + lock, err := h.lockSvc.Acquire(ctx, operationID) + if err != nil { + return nil, nil, err + } + release := func(reason string, succeeded bool) { + releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason) + } + return lock, release, nil +} + +func buildSystemOperationID(c *gin.Context, operation string) string { + key := strings.TrimSpace(c.GetHeader("Idempotency-Key")) + if key == "" { + return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36) + } + actorScope := "admin:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key + hash := service.HashIdempotencyKey(seed) + if len(hash) > 24 { + hash = hash[:24] + } + return "sysop-" + hash +} diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index de8f915f..5cbf18e6 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "net/http" "strconv" "strings" @@ -472,29 +473,36 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { billingType = *filters.BillingType } - logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", - subject.UserID, - filters.StartTime.Format(time.RFC3339), - filters.EndTime.Format(time.RFC3339), - userID, - apiKeyID, - accountID, - groupID, - model, - stream, - billingType, - req.Timezone, - ) - - task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID) - if err != nil { - logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + OperatorID int64 `json:"operator_id"` + Body CreateUsageCleanupTaskRequest `json:"body"` + }{ + OperatorID: subject.UserID, + Body: req, } + executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", + subject.UserID, + filters.StartTime.Format(time.RFC3339), + filters.EndTime.Format(time.RFC3339), + userID, + apiKeyID, + accountID, + groupID, + model, + stream, + billingType, + req.Timezone, + ) - logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) - response.Success(c, dto.UsageCleanupTaskFromService(task)) + task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID) + if err != nil { + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) + return nil, err + } + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) + return dto.UsageCleanupTaskFromService(task), nil + }) } // CancelCleanupTask handles canceling a usage cleanup task diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 248caa4b..d85202e5 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "strings" @@ -257,13 +258,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) { return } - user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes) - if err != nil { - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + UserID int64 `json:"user_id"` + Body UpdateBalanceRequest `json:"body"` + }{ + UserID: userID, + Body: req, } - - response.Success(c, dto.UserFromServiceAdmin(user)) + executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes) + if execErr != nil { + return nil, execErr + } + return dto.UserFromServiceAdmin(user), nil + }) } // GetUserAPIKeys handles getting user's API keys diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index f1a18ad2..61762744 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -2,6 +2,7 @@ package handler import ( + "context" "strconv" "time" @@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) { if req.Quota != nil { svcReq.Quota = *req.Quota } - key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, dto.APIKeyFromService(key)) + executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq) + if err != nil { + return nil, err + } + return dto.APIKeyFromService(key), nil + }) } // Update handles updating an API key diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 5e327022..42ff4a84 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -2,6 +2,7 @@ package dto import ( + "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -542,11 +543,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult for i := range r.Subscriptions { subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i])) } + statuses := make(map[string]string, len(r.Statuses)) + for userID, status := range r.Statuses { + statuses[strconv.FormatInt(userID, 10)] = status + } return &BulkAssignResult{ SuccessCount: r.SuccessCount, + CreatedCount: r.CreatedCount, + ReusedCount: r.ReusedCount, FailedCount: r.FailedCount, Subscriptions: subs, Errors: r.Errors, + Statuses: statuses, } } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 70a8c792..0cd1b241 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -395,9 +395,12 @@ type AdminUserSubscription struct { type BulkAssignResult struct { SuccessCount int `json:"success_count"` + CreatedCount int `json:"created_count"` + ReusedCount int `json:"reused_count"` FailedCount int `json:"failed_count"` Subscriptions []AdminUserSubscription `json:"subscriptions"` Errors []string `json:"errors"` + Statuses map[string]string `json:"statuses,omitempty"` } // PromoCode 注册优惠码 diff --git a/backend/internal/handler/idempotency_helper.go b/backend/internal/handler/idempotency_helper.go new file mode 100644 index 00000000..bca63b6b --- /dev/null +++ b/backend/internal/handler/idempotency_helper.go @@ -0,0 +1,65 @@ +package handler + +import ( + "context" + "strconv" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +func executeUserIdempotentJSON( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + coordinator := service.DefaultIdempotencyCoordinator() + if coordinator == nil { + data, err := execute(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) + return + } + + actorScope := "user:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "user:" + strconv.FormatInt(subject.UserID, 10) + } + + result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{ + Scope: scope, + ActorScope: actorScope, + Method: c.Request.Method, + Route: c.FullPath(), + IdempotencyKey: c.GetHeader("Idempotency-Key"), + Payload: payload, + RequireKey: true, + TTL: ttl, + }, execute) + if err != nil { + if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) { + service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close") + logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope) + } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + response.ErrorFrom(c, err) + return + } + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) +} diff --git a/backend/internal/handler/idempotency_helper_test.go b/backend/internal/handler/idempotency_helper_test.go new file mode 100644 index 00000000..e8213a2b --- /dev/null +++ b/backend/internal/handler/idempotency_helper_test.go @@ -0,0 +1,285 @@ +package handler + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userStoreUnavailableRepoStub struct{} + +func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +type userMemoryIdempotencyRepoStub struct { + mu sync.Mutex + nextID int64 + data map[string]*service.IdempotencyRecord +} + +func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub { + return &userMemoryIdempotencyRepoStub{ + nextID: 1, + data: make(map[string]*service.IdempotencyRecord), + } +} + +func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string { + return scope + "|" + keyHash +} + +func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + return &out +} + +func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + cp := r.clone(record) + cp.ID = r.nextID + r.nextID++ + r.data[k] = cp + record.ID = cp.ID + return true, nil +} + +func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.clone(r.data[r.key(scope, keyHash)]), nil +} + +func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = service.IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + return true, nil + } + return false, nil +} + +func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + return true, nil + } + return false, nil +} + +func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + rec.ErrorReason = nil + return nil + } + return nil +} + +func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.ErrorReason = &errorReason + return nil + } + return nil +} + +func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) { + return 0, nil +} + +func withUserSubject(userID int64) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID}) + c.Next() + } +} + +func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(nil) + + var executed int + router := gin.New() + router.Use(withUserSubject(1)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, executed) +} + +func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.Use(withUserSubject(2)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "k1") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + require.Equal(t, 0, executed) +} + +func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newUserMemoryIdempotencyRepoStub() + cfg := service.DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg)) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed atomic.Int32 + router := gin.New() + router.Use(withUserSubject(3)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed.Add(1) + time.Sleep(80 * time.Millisecond) + return gin.H{"ok": true}, nil + }) + }) + + call := func() (int, http.Header) { + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "same-user-key") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec.Code, rec.Header() + } + + var status1, status2 int + var wg sync.WaitGroup + wg.Add(2) + go func() { defer wg.Done(); status1, _ = call() }() + go func() { defer wg.Done(); status2, _ = call() }() + wg.Wait() + + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1) + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2) + require.Equal(t, int32(1), executed.Load()) + + status3, headers3 := call() + require.Equal(t, http.StatusOK, status3) + require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed")) + require.Equal(t, int32(1), executed.Load()) +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index b80fe4a9..79d583fd 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -53,8 +53,8 @@ func ProvideAdminHandlers( } // ProvideSystemHandler creates admin.SystemHandler with UpdateService -func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler { - return admin.NewSystemHandler(updateService) +func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler { + return admin.NewSystemHandler(updateService, lockService) } // ProvideSettingHandler creates SettingHandler with version from BuildInfo @@ -77,6 +77,8 @@ func ProvideHandlers( soraGatewayHandler *SoraGatewayHandler, settingHandler *SettingHandler, totpHandler *TotpHandler, + _ *service.IdempotencyCoordinator, + _ *service.IdempotencyCleanupService, ) *Handlers { return &Handlers{ Auth: authHandler, diff --git a/backend/internal/repository/idempotency_repo.go b/backend/internal/repository/idempotency_repo.go new file mode 100644 index 00000000..32f2faae --- /dev/null +++ b/backend/internal/repository/idempotency_repo.go @@ -0,0 +1,237 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type idempotencyRepository struct { + sql sqlExecutor +} + +func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository { + return &idempotencyRepository{sql: sqlDB} +} + +func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) { + if record == nil { + return false, nil + } + query := ` + INSERT INTO idempotency_records ( + scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at + ) VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (scope, idempotency_key_hash) DO NOTHING + RETURNING id, created_at, updated_at + ` + var createdAt time.Time + var updatedAt time.Time + err := scanSingleRow(ctx, r.sql, query, []any{ + record.Scope, + record.IdempotencyKeyHash, + record.RequestFingerprint, + record.Status, + record.LockedUntil, + record.ExpiresAt, + }, &record.ID, &createdAt, &updatedAt) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + record.CreatedAt = createdAt + record.UpdatedAt = updatedAt + return true, nil +} + +func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + query := ` + SELECT + id, scope, idempotency_key_hash, request_fingerprint, status, response_status, + response_body, error_reason, locked_until, expires_at, created_at, updated_at + FROM idempotency_records + WHERE scope = $1 AND idempotency_key_hash = $2 + ` + record := &service.IdempotencyRecord{} + var responseStatus sql.NullInt64 + var responseBody sql.NullString + var errorReason sql.NullString + var lockedUntil sql.NullTime + err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash}, + &record.ID, + &record.Scope, + &record.IdempotencyKeyHash, + &record.RequestFingerprint, + &record.Status, + &responseStatus, + &responseBody, + &errorReason, + &lockedUntil, + &record.ExpiresAt, + &record.CreatedAt, + &record.UpdatedAt, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, err + } + if responseStatus.Valid { + v := int(responseStatus.Int64) + record.ResponseStatus = &v + } + if responseBody.Valid { + v := responseBody.String + record.ResponseBody = &v + } + if errorReason.Valid { + v := errorReason.String + record.ErrorReason = &v + } + if lockedUntil.Valid { + v := lockedUntil.Time + record.LockedUntil = &v + } + return record, nil +} + +func (r *idempotencyRepository) TryReclaim( + ctx context.Context, + id int64, + fromStatus string, + now, newLockedUntil, newExpiresAt time.Time, +) (bool, error) { + query := ` + UPDATE idempotency_records + SET status = $2, + locked_until = $3, + error_reason = NULL, + updated_at = NOW(), + expires_at = $4 + WHERE id = $1 + AND status = $5 + AND (locked_until IS NULL OR locked_until <= $6) + ` + res, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusProcessing, + newLockedUntil, + newExpiresAt, + fromStatus, + now, + ) + if err != nil { + return false, err + } + affected, err := res.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *idempotencyRepository) ExtendProcessingLock( + ctx context.Context, + id int64, + requestFingerprint string, + newLockedUntil, + newExpiresAt time.Time, +) (bool, error) { + query := ` + UPDATE idempotency_records + SET locked_until = $2, + expires_at = $3, + updated_at = NOW() + WHERE id = $1 + AND status = $4 + AND request_fingerprint = $5 + ` + res, err := r.sql.ExecContext( + ctx, + query, + id, + newLockedUntil, + newExpiresAt, + service.IdempotencyStatusProcessing, + requestFingerprint, + ) + if err != nil { + return false, err + } + affected, err := res.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + query := ` + UPDATE idempotency_records + SET status = $2, + response_status = $3, + response_body = $4, + error_reason = NULL, + locked_until = NULL, + expires_at = $5, + updated_at = NOW() + WHERE id = $1 + ` + _, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusSucceeded, + responseStatus, + responseBody, + expiresAt, + ) + return err +} + +func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + query := ` + UPDATE idempotency_records + SET status = $2, + error_reason = $3, + locked_until = $4, + expires_at = $5, + updated_at = NOW() + WHERE id = $1 + ` + _, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusFailedRetryable, + errorReason, + lockedUntil, + expiresAt, + ) + return err +} + +func (r *idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) { + if limit <= 0 { + limit = 500 + } + query := ` + WITH victims AS ( + SELECT id + FROM idempotency_records + WHERE expires_at <= $1 + ORDER BY expires_at ASC + LIMIT $2 + ) + DELETE FROM idempotency_records + WHERE id IN (SELECT id FROM victims) + ` + res, err := r.sql.ExecContext(ctx, query, now, limit) + if err != nil { + return 0, err + } + return res.RowsAffected() +} diff --git a/backend/internal/repository/idempotency_repo_integration_test.go b/backend/internal/repository/idempotency_repo_integration_test.go new file mode 100644 index 00000000..29a01051 --- /dev/null +++ b/backend/internal/repository/idempotency_repo_integration_test.go @@ -0,0 +1,144 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-create"), + IdempotencyKeyHash: uniqueTestValue(t, "idem-hash"), + RequestFingerprint: uniqueTestValue(t, "idem-fp"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(30 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + require.NotZero(t, record.ID) + + duplicate := &service.IdempotencyRecord{ + Scope: record.Scope, + IdempotencyKeyHash: record.IdempotencyKeyHash, + RequestFingerprint: uniqueTestValue(t, "idem-fp-other"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(30 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err = repo.CreateProcessing(ctx, duplicate) + require.NoError(t, err) + require.False(t, owner, "same scope+key hash should be de-duplicated") +} + +func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-reclaim"), + IdempotencyKeyHash: uniqueTestValue(t, "idem-hash-reclaim"), + RequestFingerprint: uniqueTestValue(t, "idem-fp-reclaim"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(10 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + + require.NoError(t, repo.MarkFailedRetryable( + ctx, + record.ID, + "RETRYABLE_FAILURE", + now.Add(-2*time.Second), + now.Add(24*time.Hour), + )) + + newLockedUntil := now.Add(20 * time.Second) + reclaimed, err := repo.TryReclaim( + ctx, + record.ID, + service.IdempotencyStatusFailedRetryable, + now, + newLockedUntil, + now.Add(24*time.Hour), + ) + require.NoError(t, err) + require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim") + + got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, service.IdempotencyStatusProcessing, got.Status) + require.NotNil(t, got.LockedUntil) + require.True(t, got.LockedUntil.After(now)) + + require.NoError(t, repo.MarkFailedRetryable( + ctx, + record.ID, + "RETRYABLE_FAILURE", + now.Add(20*time.Second), + now.Add(24*time.Hour), + )) + + reclaimed, err = repo.TryReclaim( + ctx, + record.ID, + service.IdempotencyStatusFailedRetryable, + now, + now.Add(40*time.Second), + now.Add(24*time.Hour), + ) + require.NoError(t, err) + require.False(t, reclaimed, "within lock window should not reclaim") +} + +func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-success"), + IdempotencyKeyHash: uniqueTestValue(t, "idem-hash-success"), + RequestFingerprint: uniqueTestValue(t, "idem-fp-success"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(10 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + + require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour))) + + got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, service.IdempotencyStatusSucceeded, got.Status) + require.NotNil(t, got.ResponseStatus) + require.Equal(t, 200, *got.ResponseStatus) + require.NotNil(t, got.ResponseBody) + require.Equal(t, `{"ok":true}`, *got.ResponseBody) + require.Nil(t, got.LockedUntil) +} + +func ptrTime(v time.Time) *time.Time { + return &v +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index d91f654b..0878c43d 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -60,6 +60,7 @@ var ProviderSet = wire.NewSet( NewAnnouncementRepository, NewAnnouncementReadRepository, NewUsageLogRepository, + NewIdempotencyRepository, NewUsageCleanupRepository, NewDashboardAggregationRepository, NewSettingRepository, diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index c88906e3..c5e1cfab 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -35,6 +35,8 @@ var ( const ( apiKeyMaxErrorsPerHour = 20 apiKeyLastUsedMinTouch = 30 * time.Second + // DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。 + apiKeyLastUsedFailBackoff = 5 * time.Second ) type APIKeyRepository interface { @@ -129,7 +131,7 @@ type APIKeyService struct { authCacheL1 *ristretto.Cache authCfg apiKeyAuthCacheConfig authGroup singleflight.Group - lastUsedTouchL1 sync.Map // keyID -> time.Time + lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time) lastUsedTouchSF singleflight.Group } @@ -574,7 +576,7 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error { now := time.Now() if v, ok := s.lastUsedTouchL1.Load(keyID); ok { - if last, ok := v.(time.Time); ok && now.Sub(last) < apiKeyLastUsedMinTouch { + if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) { return nil } } @@ -582,15 +584,16 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error { _, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) { latest := time.Now() if v, ok := s.lastUsedTouchL1.Load(keyID); ok { - if last, ok := v.(time.Time); ok && latest.Sub(last) < apiKeyLastUsedMinTouch { + if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) { return nil, nil } } if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil { + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff)) return nil, fmt.Errorf("touch api key last used: %w", err) } - s.lastUsedTouchL1.Store(keyID, latest) + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch)) return nil, nil }) return err diff --git a/backend/internal/service/api_key_service_touch_last_used_test.go b/backend/internal/service/api_key_service_touch_last_used_test.go index 5c750ec5..b49bf9ce 100644 --- a/backend/internal/service/api_key_service_touch_last_used_test.go +++ b/backend/internal/service/api_key_service_touch_last_used_test.go @@ -79,8 +79,27 @@ func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) { require.ErrorContains(t, err, "touch api key last used") require.Equal(t, []int64{123}, repo.touchedIDs) - _, ok := svc.lastUsedTouchL1.Load(int64(123)) - require.False(t, ok, "failed touch should not update debounce cache") + cached, ok := svc.lastUsedTouchL1.Load(int64(123)) + require.True(t, ok, "failed touch should still update retry debounce cache") + _, isTime := cached.(time.Time) + require.True(t, isTime) +} + +func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("db write failed") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + firstErr := svc.TouchLastUsed(context.Background(), 456) + require.Error(t, firstErr) + require.ErrorContains(t, firstErr, "touch api key last used") + + secondErr := svc.TouchLastUsed(context.Background(), 456) + require.NoError(t, secondErr, "failed touch should be debounced and skip immediate retry") + require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again") } type touchSingleflightRepo struct { diff --git a/backend/internal/service/idempotency.go b/backend/internal/service/idempotency.go new file mode 100644 index 00000000..2a86bd60 --- /dev/null +++ b/backend/internal/service/idempotency.go @@ -0,0 +1,471 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" +) + +const ( + IdempotencyStatusProcessing = "processing" + IdempotencyStatusSucceeded = "succeeded" + IdempotencyStatusFailedRetryable = "failed_retryable" +) + +var ( + ErrIdempotencyKeyRequired = infraerrors.BadRequest("IDEMPOTENCY_KEY_REQUIRED", "idempotency key is required") + ErrIdempotencyKeyInvalid = infraerrors.BadRequest("IDEMPOTENCY_KEY_INVALID", "idempotency key is invalid") + ErrIdempotencyKeyConflict = infraerrors.Conflict("IDEMPOTENCY_KEY_CONFLICT", "idempotency key reused with different payload") + ErrIdempotencyInProgress = infraerrors.Conflict("IDEMPOTENCY_IN_PROGRESS", "idempotent request is still processing") + ErrIdempotencyRetryBackoff = infraerrors.Conflict("IDEMPOTENCY_RETRY_BACKOFF", "idempotent request is in retry backoff window") + ErrIdempotencyStoreUnavail = infraerrors.ServiceUnavailable("IDEMPOTENCY_STORE_UNAVAILABLE", "idempotency store unavailable") + ErrIdempotencyInvalidPayload = infraerrors.BadRequest("IDEMPOTENCY_PAYLOAD_INVALID", "failed to normalize request payload") +) + +type IdempotencyRecord struct { + ID int64 + Scope string + IdempotencyKeyHash string + RequestFingerprint string + Status string + ResponseStatus *int + ResponseBody *string + ErrorReason *string + LockedUntil *time.Time + ExpiresAt time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +type IdempotencyRepository interface { + CreateProcessing(ctx context.Context, record *IdempotencyRecord) (bool, error) + GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*IdempotencyRecord, error) + TryReclaim(ctx context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) + ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) + MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error + MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error + DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) +} + +type IdempotencyConfig struct { + DefaultTTL time.Duration + SystemOperationTTL time.Duration + ProcessingTimeout time.Duration + FailedRetryBackoff time.Duration + MaxStoredResponseLen int + ObserveOnly bool +} + +func DefaultIdempotencyConfig() IdempotencyConfig { + return IdempotencyConfig{ + DefaultTTL: 24 * time.Hour, + SystemOperationTTL: 1 * time.Hour, + ProcessingTimeout: 30 * time.Second, + FailedRetryBackoff: 5 * time.Second, + MaxStoredResponseLen: 64 * 1024, + ObserveOnly: true, // 默认先观察再强制,避免老客户端立刻中断 + } +} + +type IdempotencyExecuteOptions struct { + Scope string + ActorScope string + Method string + Route string + IdempotencyKey string + Payload any + TTL time.Duration + RequireKey bool +} + +type IdempotencyExecuteResult struct { + Data any + Replayed bool +} + +type IdempotencyCoordinator struct { + repo IdempotencyRepository + cfg IdempotencyConfig +} + +var ( + defaultIdempotencyMu sync.RWMutex + defaultIdempotencySvc *IdempotencyCoordinator +) + +func SetDefaultIdempotencyCoordinator(svc *IdempotencyCoordinator) { + defaultIdempotencyMu.Lock() + defaultIdempotencySvc = svc + defaultIdempotencyMu.Unlock() +} + +func DefaultIdempotencyCoordinator() *IdempotencyCoordinator { + defaultIdempotencyMu.RLock() + defer defaultIdempotencyMu.RUnlock() + return defaultIdempotencySvc +} + +func DefaultWriteIdempotencyTTL() time.Duration { + defaultTTL := DefaultIdempotencyConfig().DefaultTTL + if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.DefaultTTL > 0 { + return coordinator.cfg.DefaultTTL + } + return defaultTTL +} + +func DefaultSystemOperationIdempotencyTTL() time.Duration { + defaultTTL := DefaultIdempotencyConfig().SystemOperationTTL + if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.SystemOperationTTL > 0 { + return coordinator.cfg.SystemOperationTTL + } + return defaultTTL +} + +func NewIdempotencyCoordinator(repo IdempotencyRepository, cfg IdempotencyConfig) *IdempotencyCoordinator { + return &IdempotencyCoordinator{ + repo: repo, + cfg: cfg, + } +} + +func NormalizeIdempotencyKey(raw string) (string, error) { + key := strings.TrimSpace(raw) + if key == "" { + return "", nil + } + if len(key) > 128 { + return "", ErrIdempotencyKeyInvalid + } + for _, r := range key { + if r < 33 || r > 126 { + return "", ErrIdempotencyKeyInvalid + } + } + return key, nil +} + +func HashIdempotencyKey(key string) string { + sum := sha256.Sum256([]byte(key)) + return hex.EncodeToString(sum[:]) +} + +func BuildIdempotencyFingerprint(method, route, actorScope string, payload any) (string, error) { + if method == "" { + method = "POST" + } + if route == "" { + route = "/" + } + if actorScope == "" { + actorScope = "anonymous" + } + + raw, err := json.Marshal(payload) + if err != nil { + return "", ErrIdempotencyInvalidPayload.WithCause(err) + } + sum := sha256.Sum256([]byte( + strings.ToUpper(method) + "\n" + route + "\n" + actorScope + "\n" + string(raw), + )) + return hex.EncodeToString(sum[:]), nil +} + +func RetryAfterSecondsFromError(err error) int { + appErr := new(infraerrors.ApplicationError) + if !errors.As(err, &appErr) || appErr == nil || appErr.Metadata == nil { + return 0 + } + v := strings.TrimSpace(appErr.Metadata["retry_after"]) + if v == "" { + return 0 + } + seconds, convErr := strconv.Atoi(v) + if convErr != nil || seconds <= 0 { + return 0 + } + return seconds +} + +func (c *IdempotencyCoordinator) Execute( + ctx context.Context, + opts IdempotencyExecuteOptions, + execute func(context.Context) (any, error), +) (*IdempotencyExecuteResult, error) { + if execute == nil { + return nil, infraerrors.InternalServer("IDEMPOTENCY_EXECUTOR_NIL", "idempotency executor is nil") + } + + key, err := NormalizeIdempotencyKey(opts.IdempotencyKey) + if err != nil { + return nil, err + } + if key == "" { + if opts.RequireKey && !c.cfg.ObserveOnly { + return nil, ErrIdempotencyKeyRequired + } + data, execErr := execute(ctx) + if execErr != nil { + return nil, execErr + } + return &IdempotencyExecuteResult{Data: data}, nil + } + if c.repo == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "repo_nil") + return nil, ErrIdempotencyStoreUnavail + } + + if opts.Scope == "" { + return nil, infraerrors.BadRequest("IDEMPOTENCY_SCOPE_REQUIRED", "idempotency scope is required") + } + + fingerprint, err := BuildIdempotencyFingerprint(opts.Method, opts.Route, opts.ActorScope, opts.Payload) + if err != nil { + return nil, err + } + + ttl := opts.TTL + if ttl <= 0 { + ttl = c.cfg.DefaultTTL + } + now := time.Now() + expiresAt := now.Add(ttl) + lockedUntil := now.Add(c.cfg.ProcessingTimeout) + keyHash := HashIdempotencyKey(key) + + record := &IdempotencyRecord{ + Scope: opts.Scope, + IdempotencyKeyHash: keyHash, + RequestFingerprint: fingerprint, + Status: IdempotencyStatusProcessing, + LockedUntil: &lockedUntil, + ExpiresAt: expiresAt, + } + + owner, err := c.repo.CreateProcessing(ctx, record) + if err != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "create_processing_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "create_processing", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(err) + } + if owner { + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "new_claim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "none->processing", false, map[string]string{ + "claim_mode": "new", + }) + } + if !owner { + existing, getErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash) + if getErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "get_existing", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(getErr) + } + if existing == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "missing_existing", + }) + return nil, ErrIdempotencyStoreUnavail + } + if existing.RequestFingerprint != fingerprint { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil) + return nil, ErrIdempotencyKeyConflict + } + reclaimedByExpired := false + if !existing.ExpiresAt.After(now) { + taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, existing.Status, now, lockedUntil, expiresAt) + if reclaimErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_expired_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->store_unavailable", false, map[string]string{ + "operation": "try_reclaim_expired", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr) + } + if taken { + reclaimedByExpired = true + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "expired_reclaim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->processing", false, map[string]string{ + "claim_mode": "expired_reclaim", + }) + record.ID = existing.ID + } else { + latest, latestErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash) + if latestErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_after_expired_reclaim_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "get_existing_after_expired_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(latestErr) + } + if latest == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing_after_expired_reclaim") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "missing_existing_after_expired_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail + } + if latest.RequestFingerprint != fingerprint { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil) + return nil, ErrIdempotencyKeyConflict + } + existing = latest + } + } + + if !reclaimedByExpired { + switch existing.Status { + case IdempotencyStatusSucceeded: + data, parseErr := c.decodeStoredResponse(existing.ResponseBody) + if parseErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "decode_stored_response_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->store_unavailable", false, map[string]string{ + "operation": "decode_stored_response", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(parseErr) + } + recordIdempotencyReplay(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->replayed", true, nil) + return &IdempotencyExecuteResult{Data: data, Replayed: true}, nil + case IdempotencyStatusProcessing: + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "in_progress"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->conflict", false, nil) + return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now) + case IdempotencyStatusFailedRetryable: + if existing.LockedUntil != nil && existing.LockedUntil.After(now) { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "retry_backoff"}) + recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->retry_backoff_conflict", false, nil) + return nil, c.conflictWithRetryAfter(ErrIdempotencyRetryBackoff, existing.LockedUntil, now) + } + taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, IdempotencyStatusFailedRetryable, now, lockedUntil, expiresAt) + if reclaimErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->store_unavailable", false, map[string]string{ + "operation": "try_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr) + } + if !taken { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "reclaim_race"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->conflict", false, map[string]string{ + "conflict": "reclaim_race", + }) + return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now) + } + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "reclaim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->processing", false, map[string]string{ + "claim_mode": "reclaim", + }) + record.ID = existing.ID + default: + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "unexpected_status"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->conflict", false, map[string]string{ + "status": existing.Status, + }) + return nil, ErrIdempotencyKeyConflict + } + } + } + + if record.ID == 0 { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "record_id_missing") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "record_id_missing", + }) + return nil, ErrIdempotencyStoreUnavail + } + + execStart := time.Now() + defer func() { + recordIdempotencyProcessingDuration(opts.Route, opts.Scope, time.Since(execStart), nil) + }() + + data, execErr := execute(ctx) + if execErr != nil { + backoffUntil := time.Now().Add(c.cfg.FailedRetryBackoff) + reason := infraerrors.Reason(execErr) + if reason == "" { + reason = "EXECUTION_FAILED" + } + recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->failed_retryable", false, map[string]string{ + "reason": reason, + }) + if markErr := c.repo.MarkFailedRetryable(ctx, record.ID, reason, backoffUntil, expiresAt); markErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_failed_retryable_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "mark_failed_retryable", + }) + } + return nil, execErr + } + + storedBody, marshalErr := c.marshalStoredResponse(data) + if marshalErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "marshal_response_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "marshal_response", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(marshalErr) + } + if markErr := c.repo.MarkSucceeded(ctx, record.ID, 200, storedBody, expiresAt); markErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_succeeded_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "mark_succeeded", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(markErr) + } + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->succeeded", false, nil) + + return &IdempotencyExecuteResult{Data: data}, nil +} + +func (c *IdempotencyCoordinator) conflictWithRetryAfter(base *infraerrors.ApplicationError, lockedUntil *time.Time, now time.Time) error { + if lockedUntil == nil { + return base + } + sec := int(lockedUntil.Sub(now).Seconds()) + if sec <= 0 { + sec = 1 + } + return base.WithMetadata(map[string]string{"retry_after": strconv.Itoa(sec)}) +} + +func (c *IdempotencyCoordinator) marshalStoredResponse(data any) (string, error) { + raw, err := json.Marshal(data) + if err != nil { + return "", err + } + redacted := logredact.RedactText(string(raw)) + if c.cfg.MaxStoredResponseLen > 0 && len(redacted) > c.cfg.MaxStoredResponseLen { + redacted = redacted[:c.cfg.MaxStoredResponseLen] + "...(truncated)" + } + return redacted, nil +} + +func (c *IdempotencyCoordinator) decodeStoredResponse(stored *string) (any, error) { + if stored == nil || strings.TrimSpace(*stored) == "" { + return map[string]any{}, nil + } + var out any + if err := json.Unmarshal([]byte(*stored), &out); err != nil { + return nil, fmt.Errorf("decode stored response: %w", err) + } + return out, nil +} diff --git a/backend/internal/service/idempotency_cleanup_service.go b/backend/internal/service/idempotency_cleanup_service.go new file mode 100644 index 00000000..aaf6949a --- /dev/null +++ b/backend/internal/service/idempotency_cleanup_service.go @@ -0,0 +1,91 @@ +package service + +import ( + "context" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// IdempotencyCleanupService 定期清理已过期的幂等记录,避免表无限增长。 +type IdempotencyCleanupService struct { + repo IdempotencyRepository + interval time.Duration + batch int + + startOnce sync.Once + stopOnce sync.Once + stopCh chan struct{} +} + +func NewIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService { + interval := 60 * time.Second + batch := 500 + if cfg != nil { + if cfg.Idempotency.CleanupIntervalSeconds > 0 { + interval = time.Duration(cfg.Idempotency.CleanupIntervalSeconds) * time.Second + } + if cfg.Idempotency.CleanupBatchSize > 0 { + batch = cfg.Idempotency.CleanupBatchSize + } + } + return &IdempotencyCleanupService{ + repo: repo, + interval: interval, + batch: batch, + stopCh: make(chan struct{}), + } +} + +func (s *IdempotencyCleanupService) Start() { + if s == nil || s.repo == nil { + return + } + s.startOnce.Do(func() { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] started interval=%s batch=%d", s.interval, s.batch) + go s.runLoop() + }) +} + +func (s *IdempotencyCleanupService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] stopped") + }) +} + +func (s *IdempotencyCleanupService) runLoop() { + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + // 启动后先清理一轮,防止重启后积压。 + s.cleanupOnce() + + for { + select { + case <-ticker.C: + s.cleanupOnce() + case <-s.stopCh: + return + } + } +} + +func (s *IdempotencyCleanupService) cleanupOnce() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + deleted, err := s.repo.DeleteExpired(ctx, time.Now(), s.batch) + if err != nil { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleanup failed err=%v", err) + return + } + if deleted > 0 { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleaned expired records count=%d", deleted) + } +} diff --git a/backend/internal/service/idempotency_cleanup_service_test.go b/backend/internal/service/idempotency_cleanup_service_test.go new file mode 100644 index 00000000..556ff364 --- /dev/null +++ b/backend/internal/service/idempotency_cleanup_service_test.go @@ -0,0 +1,69 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type idempotencyCleanupRepoStub struct { + deleteCalls int + lastLimit int + deleteErr error +} + +func (r *idempotencyCleanupRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, nil +} +func (r *idempotencyCleanupRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return nil +} +func (r *idempotencyCleanupRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (r *idempotencyCleanupRepoStub) DeleteExpired(_ context.Context, _ time.Time, limit int) (int64, error) { + r.deleteCalls++ + r.lastLimit = limit + if r.deleteErr != nil { + return 0, r.deleteErr + } + return 1, nil +} + +func TestNewIdempotencyCleanupService_UsesConfig(t *testing.T) { + repo := &idempotencyCleanupRepoStub{} + cfg := &config.Config{ + Idempotency: config.IdempotencyConfig{ + CleanupIntervalSeconds: 7, + CleanupBatchSize: 321, + }, + } + svc := NewIdempotencyCleanupService(repo, cfg) + require.Equal(t, 7*time.Second, svc.interval) + require.Equal(t, 321, svc.batch) +} + +func TestIdempotencyCleanupService_CleanupOnce(t *testing.T) { + repo := &idempotencyCleanupRepoStub{} + svc := NewIdempotencyCleanupService(repo, &config.Config{ + Idempotency: config.IdempotencyConfig{ + CleanupBatchSize: 99, + }, + }) + + svc.cleanupOnce() + require.Equal(t, 1, repo.deleteCalls) + require.Equal(t, 99, repo.lastLimit) +} diff --git a/backend/internal/service/idempotency_observability.go b/backend/internal/service/idempotency_observability.go new file mode 100644 index 00000000..f1bf2df2 --- /dev/null +++ b/backend/internal/service/idempotency_observability.go @@ -0,0 +1,171 @@ +package service + +import ( + "sort" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// IdempotencyMetricsSnapshot 提供幂等核心指标快照(进程内累计)。 +type IdempotencyMetricsSnapshot struct { + ClaimTotal uint64 `json:"claim_total"` + ReplayTotal uint64 `json:"replay_total"` + ConflictTotal uint64 `json:"conflict_total"` + RetryBackoffTotal uint64 `json:"retry_backoff_total"` + ProcessingDurationCount uint64 `json:"processing_duration_count"` + ProcessingDurationTotalMs float64 `json:"processing_duration_total_ms"` + StoreUnavailableTotal uint64 `json:"store_unavailable_total"` +} + +type idempotencyMetrics struct { + claimTotal atomic.Uint64 + replayTotal atomic.Uint64 + conflictTotal atomic.Uint64 + retryBackoffTotal atomic.Uint64 + processingDurationCount atomic.Uint64 + processingDurationMicros atomic.Uint64 + storeUnavailableTotal atomic.Uint64 +} + +var defaultIdempotencyMetrics idempotencyMetrics + +// GetIdempotencyMetricsSnapshot 返回当前幂等指标快照。 +func GetIdempotencyMetricsSnapshot() IdempotencyMetricsSnapshot { + totalMicros := defaultIdempotencyMetrics.processingDurationMicros.Load() + return IdempotencyMetricsSnapshot{ + ClaimTotal: defaultIdempotencyMetrics.claimTotal.Load(), + ReplayTotal: defaultIdempotencyMetrics.replayTotal.Load(), + ConflictTotal: defaultIdempotencyMetrics.conflictTotal.Load(), + RetryBackoffTotal: defaultIdempotencyMetrics.retryBackoffTotal.Load(), + ProcessingDurationCount: defaultIdempotencyMetrics.processingDurationCount.Load(), + ProcessingDurationTotalMs: float64(totalMicros) / 1000.0, + StoreUnavailableTotal: defaultIdempotencyMetrics.storeUnavailableTotal.Load(), + } +} + +func recordIdempotencyClaim(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.claimTotal.Add(1) + logIdempotencyMetric("idempotency_claim_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyReplay(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.replayTotal.Add(1) + logIdempotencyMetric("idempotency_replay_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyConflict(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.conflictTotal.Add(1) + logIdempotencyMetric("idempotency_conflict_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyRetryBackoff(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.retryBackoffTotal.Add(1) + logIdempotencyMetric("idempotency_retry_backoff_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyProcessingDuration(endpoint, scope string, duration time.Duration, attrs map[string]string) { + if duration < 0 { + duration = 0 + } + defaultIdempotencyMetrics.processingDurationCount.Add(1) + defaultIdempotencyMetrics.processingDurationMicros.Add(uint64(duration.Microseconds())) + logIdempotencyMetric("idempotency_processing_duration_ms", endpoint, scope, strconv.FormatFloat(duration.Seconds()*1000, 'f', 3, 64), attrs) +} + +// RecordIdempotencyStoreUnavailable 记录幂等存储不可用事件(用于降级路径观测)。 +func RecordIdempotencyStoreUnavailable(endpoint, scope, strategy string) { + defaultIdempotencyMetrics.storeUnavailableTotal.Add(1) + attrs := map[string]string{} + if strategy != "" { + attrs["strategy"] = strategy + } + logIdempotencyMetric("idempotency_store_unavailable_total", endpoint, scope, "1", attrs) +} + +func logIdempotencyAudit(endpoint, scope, keyHash, stateTransition string, replayed bool, attrs map[string]string) { + var b strings.Builder + builderWriteString(&b, "[IdempotencyAudit]") + builderWriteString(&b, " endpoint=") + builderWriteString(&b, safeAuditField(endpoint)) + builderWriteString(&b, " scope=") + builderWriteString(&b, safeAuditField(scope)) + builderWriteString(&b, " key_hash=") + builderWriteString(&b, safeAuditField(keyHash)) + builderWriteString(&b, " state_transition=") + builderWriteString(&b, safeAuditField(stateTransition)) + builderWriteString(&b, " replayed=") + builderWriteString(&b, strconv.FormatBool(replayed)) + if len(attrs) > 0 { + appendSortedAttrs(&b, attrs) + } + logger.LegacyPrintf("service.idempotency", "%s", b.String()) +} + +func logIdempotencyMetric(name, endpoint, scope, value string, attrs map[string]string) { + var b strings.Builder + builderWriteString(&b, "[IdempotencyMetric]") + builderWriteString(&b, " name=") + builderWriteString(&b, safeAuditField(name)) + builderWriteString(&b, " endpoint=") + builderWriteString(&b, safeAuditField(endpoint)) + builderWriteString(&b, " scope=") + builderWriteString(&b, safeAuditField(scope)) + builderWriteString(&b, " value=") + builderWriteString(&b, safeAuditField(value)) + if len(attrs) > 0 { + appendSortedAttrs(&b, attrs) + } + logger.LegacyPrintf("service.idempotency", "%s", b.String()) +} + +func appendSortedAttrs(builder *strings.Builder, attrs map[string]string) { + if len(attrs) == 0 { + return + } + keys := make([]string, 0, len(attrs)) + for k := range attrs { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + builderWriteByte(builder, ' ') + builderWriteString(builder, k) + builderWriteByte(builder, '=') + builderWriteString(builder, safeAuditField(attrs[k])) + } +} + +func safeAuditField(v string) string { + value := strings.TrimSpace(v) + if value == "" { + return "-" + } + // 日志按 key=value 输出,替换空白避免解析歧义。 + value = strings.ReplaceAll(value, "\n", "_") + value = strings.ReplaceAll(value, "\r", "_") + value = strings.ReplaceAll(value, "\t", "_") + value = strings.ReplaceAll(value, " ", "_") + return value +} + +func resetIdempotencyMetricsForTest() { + defaultIdempotencyMetrics.claimTotal.Store(0) + defaultIdempotencyMetrics.replayTotal.Store(0) + defaultIdempotencyMetrics.conflictTotal.Store(0) + defaultIdempotencyMetrics.retryBackoffTotal.Store(0) + defaultIdempotencyMetrics.processingDurationCount.Store(0) + defaultIdempotencyMetrics.processingDurationMicros.Store(0) + defaultIdempotencyMetrics.storeUnavailableTotal.Store(0) +} + +func builderWriteString(builder *strings.Builder, value string) { + _, _ = builder.WriteString(value) +} + +func builderWriteByte(builder *strings.Builder, value byte) { + _ = builder.WriteByte(value) +} diff --git a/backend/internal/service/idempotency_test.go b/backend/internal/service/idempotency_test.go new file mode 100644 index 00000000..6ff75d1c --- /dev/null +++ b/backend/internal/service/idempotency_test.go @@ -0,0 +1,805 @@ +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type inMemoryIdempotencyRepo struct { + mu sync.Mutex + nextID int64 + data map[string]*IdempotencyRecord +} + +func newInMemoryIdempotencyRepo() *inMemoryIdempotencyRepo { + return &inMemoryIdempotencyRepo{ + nextID: 1, + data: make(map[string]*IdempotencyRecord), + } +} + +func (r *inMemoryIdempotencyRepo) key(scope, hash string) string { + return scope + "|" + hash +} + +func cloneRecord(in *IdempotencyRecord) *IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + return &out +} + +func (r *inMemoryIdempotencyRepo) CreateProcessing(_ context.Context, record *IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + rec := cloneRecord(record) + rec.ID = r.nextID + rec.CreatedAt = time.Now() + rec.UpdatedAt = rec.CreatedAt + r.nextID++ + r.data[k] = rec + record.ID = rec.ID + record.CreatedAt = rec.CreatedAt + record.UpdatedAt = rec.UpdatedAt + return true, nil +} + +func (r *inMemoryIdempotencyRepo) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return cloneRecord(r.data[r.key(scope, keyHash)]), nil +} + +func (r *inMemoryIdempotencyRepo) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + rec.UpdatedAt = time.Now() + return true, nil + } + return false, nil +} + +func (r *inMemoryIdempotencyRepo) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.UpdatedAt = time.Now() + return true, nil + } + return false, nil +} + +func (r *inMemoryIdempotencyRepo) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.UpdatedAt = time.Now() + rec.ErrorReason = nil + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + return nil + } + return errors.New("record not found") +} + +func (r *inMemoryIdempotencyRepo) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.UpdatedAt = time.Now() + rec.ErrorReason = &errorReason + return nil + } + return errors.New("record not found") +} + +func (r *inMemoryIdempotencyRepo) DeleteExpired(_ context.Context, now time.Time, _ int) (int64, error) { + r.mu.Lock() + defer r.mu.Unlock() + var deleted int64 + for k, rec := range r.data { + if !rec.ExpiresAt.After(now) { + delete(r.data, k) + deleted++ + } + } + return deleted, nil +} + +func TestIdempotencyCoordinator_RequireKey(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.ObserveOnly = false + coordinator := NewIdempotencyCoordinator(repo, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "admin:1", + RequireKey: true, + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyRequired)) +} + +func TestIdempotencyCoordinator_ReplaySucceededResult(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(repo, cfg) + + execCount := 0 + exec := func(ctx context.Context) (any, error) { + execCount++ + return map[string]any{"count": execCount}, nil + } + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-1", + Payload: map[string]any{"a": 1}, + } + + first, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.False(t, first.Replayed) + + second, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.True(t, second.Replayed) + require.Equal(t, 1, execCount, "second request should replay without executing business logic") + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ClaimTotal) + require.Equal(t, uint64(1), metrics.ReplayTotal) +} + +func TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope.expired", + Method: "POST", + Route: "/test/expired", + ActorScope: "user:99", + RequireKey: true, + IdempotencyKey: "expired-case", + Payload: map[string]any{"k": "v"}, + } + + execCount := 0 + exec := func(ctx context.Context) (any, error) { + execCount++ + return map[string]any{"count": execCount}, nil + } + + first, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, first) + require.False(t, first.Replayed) + require.Equal(t, 1, execCount) + + keyHash := HashIdempotencyKey(opts.IdempotencyKey) + repo.mu.Lock() + existing := repo.data[repo.key(opts.Scope, keyHash)] + require.NotNil(t, existing) + existing.ExpiresAt = time.Now().Add(-time.Second) + repo.mu.Unlock() + + second, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, second) + require.False(t, second.Replayed, "expired record should be reclaimed and execute business logic again") + require.Equal(t, 2, execCount) + + third, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, third) + require.True(t, third.Replayed) + payload, ok := third.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, float64(2), payload["count"]) + + metrics := GetIdempotencyMetricsSnapshot() + require.GreaterOrEqual(t, metrics.ClaimTotal, uint64(2)) + require.GreaterOrEqual(t, metrics.ReplayTotal, uint64(1)) +} + +func TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(repo, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-2", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-2", + Payload: map[string]any{"a": 2}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyConflict)) + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ConflictTotal) +} + +func TestIdempotencyCoordinator_BackoffAfterRetryableFailure(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.FailedRetryBackoff = 2 * time.Second + coordinator := NewIdempotencyCoordinator(repo, cfg) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-3", + Payload: map[string]any{"a": 1}, + } + + _, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + return nil, infraerrors.InternalServer("UPSTREAM_ERROR", "upstream error") + }) + require.Error(t, err) + + _, err = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyRetryBackoff)) + require.Greater(t, RetryAfterSecondsFromError(err), 0) + + metrics := GetIdempotencyMetricsSnapshot() + require.GreaterOrEqual(t, metrics.RetryBackoffTotal, uint64(2)) + require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1)) + require.GreaterOrEqual(t, metrics.ProcessingDurationCount, uint64(1)) +} + +func TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + coordinator := NewIdempotencyCoordinator(repo, cfg) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope.concurrent", + Method: "POST", + Route: "/test/concurrent", + ActorScope: "user:7", + RequireKey: true, + IdempotencyKey: "concurrent-case", + Payload: map[string]any{"v": 1}, + } + + var execCount int32 + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + atomic.AddInt32(&execCount, 1) + time.Sleep(80 * time.Millisecond) + return map[string]any{"ok": true}, nil + }) + }() + } + wg.Wait() + + replayed, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + atomic.AddInt32(&execCount, 1) + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + require.True(t, replayed.Replayed) + require.Equal(t, int32(1), atomic.LoadInt32(&execCount), "concurrent same-key requests should execute business side-effect once") + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ClaimTotal) + require.Equal(t, uint64(1), metrics.ReplayTotal) + require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1)) +} + +type failingIdempotencyRepo struct{} + +func (failingIdempotencyRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (failingIdempotencyRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (failingIdempotencyRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (failingIdempotencyRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +func TestIdempotencyCoordinator_StoreUnavailableMetrics(t *testing.T) { + resetIdempotencyMetricsForTest() + coordinator := NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig()) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope.unavailable", + Method: "POST", + Route: "/test/unavailable", + ActorScope: "admin:1", + RequireKey: true, + IdempotencyKey: "case-unavailable", + Payload: map[string]any{"v": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + require.GreaterOrEqual(t, GetIdempotencyMetricsSnapshot().StoreUnavailableTotal, uint64(1)) +} + +func TestDefaultIdempotencyCoordinatorAndTTLs(t *testing.T) { + SetDefaultIdempotencyCoordinator(nil) + require.Nil(t, DefaultIdempotencyCoordinator()) + require.Equal(t, DefaultIdempotencyConfig().DefaultTTL, DefaultWriteIdempotencyTTL()) + require.Equal(t, DefaultIdempotencyConfig().SystemOperationTTL, DefaultSystemOperationIdempotencyTTL()) + + coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + DefaultTTL: 2 * time.Hour, + SystemOperationTTL: 15 * time.Minute, + ProcessingTimeout: 10 * time.Second, + FailedRetryBackoff: 3 * time.Second, + ObserveOnly: false, + }) + SetDefaultIdempotencyCoordinator(coordinator) + t.Cleanup(func() { + SetDefaultIdempotencyCoordinator(nil) + }) + + require.Same(t, coordinator, DefaultIdempotencyCoordinator()) + require.Equal(t, 2*time.Hour, DefaultWriteIdempotencyTTL()) + require.Equal(t, 15*time.Minute, DefaultSystemOperationIdempotencyTTL()) +} + +func TestNormalizeIdempotencyKeyAndFingerprint(t *testing.T) { + key, err := NormalizeIdempotencyKey(" abc-123 ") + require.NoError(t, err) + require.Equal(t, "abc-123", key) + + key, err = NormalizeIdempotencyKey("") + require.NoError(t, err) + require.Equal(t, "", key) + + _, err = NormalizeIdempotencyKey(string(make([]byte, 129))) + require.Error(t, err) + + _, err = NormalizeIdempotencyKey("bad\nkey") + require.Error(t, err) + + fp1, err := BuildIdempotencyFingerprint("", "", "", map[string]any{"a": 1}) + require.NoError(t, err) + require.NotEmpty(t, fp1) + fp2, err := BuildIdempotencyFingerprint("POST", "/", "anonymous", map[string]any{"a": 1}) + require.NoError(t, err) + require.Equal(t, fp1, fp2) + + _, err = BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"bad": make(chan int)}) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyInvalidPayload), infraerrors.Code(err)) +} + +func TestRetryAfterSecondsFromErrorBranches(t *testing.T) { + require.Equal(t, 0, RetryAfterSecondsFromError(nil)) + require.Equal(t, 0, RetryAfterSecondsFromError(errors.New("plain"))) + + err := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "12"}) + require.Equal(t, 12, RetryAfterSecondsFromError(err)) + + err = ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "bad"}) + require.Equal(t, 0, RetryAfterSecondsFromError(err)) +} + +func TestIdempotencyCoordinator_ExecuteNilExecutorAndNoKeyPassThrough(t *testing.T) { + repo := newInMemoryIdempotencyRepo() + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Payload: map[string]any{"a": 1}, + }, nil) + require.Error(t, err) + require.Equal(t, "IDEMPOTENCY_EXECUTOR_NIL", infraerrors.Reason(err)) + + called := 0 + result, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + RequireKey: true, + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + called++ + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + require.Equal(t, 1, called) + require.NotNil(t, result) + require.False(t, result.Replayed) +} + +type noIDOwnerRepo struct{} + +func (noIDOwnerRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return true, nil +} +func (noIDOwnerRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, nil +} +func (noIDOwnerRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, nil +} +func (noIDOwnerRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (noIDOwnerRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { return nil } +func (noIDOwnerRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (noIDOwnerRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { return 0, nil } + +func TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing(t *testing.T) { + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(nil, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + coordinator = NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + IdempotencyKey: "k2", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, "IDEMPOTENCY_SCOPE_REQUIRED", infraerrors.Reason(err)) + + coordinator = NewIdempotencyCoordinator(noIDOwnerRepo{}, cfg) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-no-id", + IdempotencyKey: "k3", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) +} + +type conflictBranchRepo struct { + existing *IdempotencyRecord + tryReclaimErr error + tryReclaimOK bool +} + +func (r *conflictBranchRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, nil +} +func (r *conflictBranchRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return cloneRecord(r.existing), nil +} +func (r *conflictBranchRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + if r.tryReclaimErr != nil { + return false, r.tryReclaimErr + } + return r.tryReclaimOK, nil +} +func (r *conflictBranchRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *conflictBranchRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return nil +} +func (r *conflictBranchRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (r *conflictBranchRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, nil +} + +func TestIdempotencyCoordinator_ConflictBranchesAndDecodeError(t *testing.T) { + now := time.Now() + fp, err := BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"a": 1}) + require.NoError(t, err) + badBody := "{bad-json" + repo := &conflictBranchRepo{ + existing: &IdempotencyRecord{ + ID: 1, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: IdempotencyStatusSucceeded, + ResponseBody: &badBody, + ExpiresAt: now.Add(time.Hour), + }, + } + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.existing = &IdempotencyRecord{ + ID: 2, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: "unknown", + ExpiresAt: now.Add(time.Hour), + } + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyKeyConflict), infraerrors.Code(err)) + + repo.existing = &IdempotencyRecord{ + ID: 3, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: IdempotencyStatusFailedRetryable, + LockedUntil: ptrTime(now.Add(-time.Second)), + ExpiresAt: now.Add(time.Hour), + } + repo.tryReclaimErr = errors.New("reclaim down") + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.tryReclaimErr = nil + repo.tryReclaimOK = false + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyInProgress), infraerrors.Code(err)) +} + +type markBehaviorRepo struct { + inMemoryIdempotencyRepo + failMarkSucceeded bool + failMarkFailed bool +} + +func (r *markBehaviorRepo) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + if r.failMarkSucceeded { + return errors.New("mark succeeded failed") + } + return r.inMemoryIdempotencyRepo.MarkSucceeded(ctx, id, responseStatus, responseBody, expiresAt) +} + +func (r *markBehaviorRepo) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + if r.failMarkFailed { + return errors.New("mark failed retryable failed") + } + return r.inMemoryIdempotencyRepo.MarkFailedRetryable(ctx, id, errorReason, lockedUntil, expiresAt) +} + +func TestIdempotencyCoordinator_MarkAndMarshalBranches(t *testing.T) { + repo := &markBehaviorRepo{inMemoryIdempotencyRepo: *newInMemoryIdempotencyRepo()} + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + repo.failMarkSucceeded = true + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-success", + IdempotencyKey: "k1", + Method: "POST", + Route: "/ok", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.failMarkSucceeded = false + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-marshal", + IdempotencyKey: "k2", + Method: "POST", + Route: "/bad", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"bad": make(chan int)}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.failMarkFailed = true + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-fail", + IdempotencyKey: "k3", + Method: "POST", + Route: "/fail", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return nil, errors.New("plain failure") + }) + require.Error(t, err) + require.Equal(t, "plain failure", err.Error()) +} + +func TestIdempotencyCoordinator_HelperBranches(t *testing.T) { + c := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + DefaultTTL: time.Hour, + SystemOperationTTL: time.Hour, + ProcessingTimeout: time.Second, + FailedRetryBackoff: time.Second, + MaxStoredResponseLen: 12, + ObserveOnly: false, + }) + + // conflictWithRetryAfter without locked_until should return base error. + base := ErrIdempotencyInProgress + err := c.conflictWithRetryAfter(base, nil, time.Now()) + require.Equal(t, infraerrors.Code(base), infraerrors.Code(err)) + + // marshalStoredResponse should truncate. + body, err := c.marshalStoredResponse(map[string]any{"long": "abcdefghijklmnopqrstuvwxyz"}) + require.NoError(t, err) + require.Contains(t, body, "...(truncated)") + + // decodeStoredResponse empty and invalid json. + out, err := c.decodeStoredResponse(nil) + require.NoError(t, err) + _, ok := out.(map[string]any) + require.True(t, ok) + + invalid := "{invalid" + _, err = c.decodeStoredResponse(&invalid) + require.Error(t, err) +} diff --git a/backend/internal/service/subscription_assign_idempotency_test.go b/backend/internal/service/subscription_assign_idempotency_test.go new file mode 100644 index 00000000..0defafba --- /dev/null +++ b/backend/internal/service/subscription_assign_idempotency_test.go @@ -0,0 +1,389 @@ +package service + +import ( + "context" + "strconv" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type groupRepoNoop struct{} + +func (groupRepoNoop) Create(context.Context, *Group) error { panic("unexpected Create call") } +func (groupRepoNoop) GetByID(context.Context, int64) (*Group, error) { + panic("unexpected GetByID call") +} +func (groupRepoNoop) GetByIDLite(context.Context, int64) (*Group, error) { + panic("unexpected GetByIDLite call") +} +func (groupRepoNoop) Update(context.Context, *Group) error { panic("unexpected Update call") } +func (groupRepoNoop) Delete(context.Context, int64) error { panic("unexpected Delete call") } +func (groupRepoNoop) DeleteCascade(context.Context, int64) ([]int64, error) { + panic("unexpected DeleteCascade call") +} +func (groupRepoNoop) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected List call") +} +func (groupRepoNoop) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} +func (groupRepoNoop) ListActive(context.Context) ([]Group, error) { + panic("unexpected ListActive call") +} +func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, error) { + panic("unexpected ListActiveByPlatform call") +} +func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) { + panic("unexpected ExistsByName call") +} +func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) { + panic("unexpected GetAccountCount call") +} +func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { + panic("unexpected DeleteAccountGroupsByGroupID call") +} +func (groupRepoNoop) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) { + panic("unexpected GetAccountIDsByGroupIDs call") +} +func (groupRepoNoop) BindAccountsToGroup(context.Context, int64, []int64) error { + panic("unexpected BindAccountsToGroup call") +} +func (groupRepoNoop) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error { + panic("unexpected UpdateSortOrders call") +} + +type subscriptionGroupRepoStub struct { + groupRepoNoop + group *Group +} + +func (s *subscriptionGroupRepoStub) GetByID(context.Context, int64) (*Group, error) { + return s.group, nil +} + +type userSubRepoNoop struct{} + +func (userSubRepoNoop) Create(context.Context, *UserSubscription) error { + panic("unexpected Create call") +} +func (userSubRepoNoop) GetByID(context.Context, int64) (*UserSubscription, error) { + panic("unexpected GetByID call") +} +func (userSubRepoNoop) GetByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) { + panic("unexpected GetByUserIDAndGroupID call") +} +func (userSubRepoNoop) GetActiveByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) { + panic("unexpected GetActiveByUserIDAndGroupID call") +} +func (userSubRepoNoop) Update(context.Context, *UserSubscription) error { + panic("unexpected Update call") +} +func (userSubRepoNoop) Delete(context.Context, int64) error { panic("unexpected Delete call") } +func (userSubRepoNoop) ListByUserID(context.Context, int64) ([]UserSubscription, error) { + panic("unexpected ListByUserID call") +} +func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscription, error) { + panic("unexpected ListActiveByUserID call") +} +func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} +func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { + panic("unexpected List call") +} +func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) { + panic("unexpected ExistsByUserIDAndGroupID call") +} +func (userSubRepoNoop) ExtendExpiry(context.Context, int64, time.Time) error { + panic("unexpected ExtendExpiry call") +} +func (userSubRepoNoop) UpdateStatus(context.Context, int64, string) error { + panic("unexpected UpdateStatus call") +} +func (userSubRepoNoop) UpdateNotes(context.Context, int64, string) error { + panic("unexpected UpdateNotes call") +} +func (userSubRepoNoop) ActivateWindows(context.Context, int64, time.Time) error { + panic("unexpected ActivateWindows call") +} +func (userSubRepoNoop) ResetDailyUsage(context.Context, int64, time.Time) error { + panic("unexpected ResetDailyUsage call") +} +func (userSubRepoNoop) ResetWeeklyUsage(context.Context, int64, time.Time) error { + panic("unexpected ResetWeeklyUsage call") +} +func (userSubRepoNoop) ResetMonthlyUsage(context.Context, int64, time.Time) error { + panic("unexpected ResetMonthlyUsage call") +} +func (userSubRepoNoop) IncrementUsage(context.Context, int64, float64) error { + panic("unexpected IncrementUsage call") +} +func (userSubRepoNoop) BatchUpdateExpiredStatus(context.Context) (int64, error) { + panic("unexpected BatchUpdateExpiredStatus call") +} + +type subscriptionUserSubRepoStub struct { + userSubRepoNoop + + nextID int64 + byID map[int64]*UserSubscription + byUserGroup map[string]*UserSubscription + createCalls int +} + +func newSubscriptionUserSubRepoStub() *subscriptionUserSubRepoStub { + return &subscriptionUserSubRepoStub{ + nextID: 1, + byID: make(map[int64]*UserSubscription), + byUserGroup: make(map[string]*UserSubscription), + } +} + +func (s *subscriptionUserSubRepoStub) key(userID, groupID int64) string { + return strconvFormatInt(userID) + ":" + strconvFormatInt(groupID) +} + +func (s *subscriptionUserSubRepoStub) seed(sub *UserSubscription) { + if sub == nil { + return + } + cp := *sub + if cp.ID == 0 { + cp.ID = s.nextID + s.nextID++ + } + s.byID[cp.ID] = &cp + s.byUserGroup[s.key(cp.UserID, cp.GroupID)] = &cp +} + +func (s *subscriptionUserSubRepoStub) ExistsByUserIDAndGroupID(_ context.Context, userID, groupID int64) (bool, error) { + _, ok := s.byUserGroup[s.key(userID, groupID)] + return ok, nil +} + +func (s *subscriptionUserSubRepoStub) GetByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) { + sub := s.byUserGroup[s.key(userID, groupID)] + if sub == nil { + return nil, ErrSubscriptionNotFound + } + cp := *sub + return &cp, nil +} + +func (s *subscriptionUserSubRepoStub) Create(_ context.Context, sub *UserSubscription) error { + if sub == nil { + return nil + } + s.createCalls++ + cp := *sub + if cp.ID == 0 { + cp.ID = s.nextID + s.nextID++ + } + sub.ID = cp.ID + s.byID[cp.ID] = &cp + s.byUserGroup[s.key(cp.UserID, cp.GroupID)] = &cp + return nil +} + +func (s *subscriptionUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) { + sub := s.byID[id] + if sub == nil { + return nil, ErrSubscriptionNotFound + } + cp := *sub + return &cp, nil +} + +func TestAssignSubscriptionReuseWhenSemanticsMatch(t *testing.T) { + start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC) + groupRepo := &subscriptionGroupRepoStub{ + group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription}, + } + subRepo := newSubscriptionUserSubRepoStub() + subRepo.seed(&UserSubscription{ + ID: 10, + UserID: 1001, + GroupID: 1, + StartsAt: start, + ExpiresAt: start.AddDate(0, 0, 30), + Notes: "init", + }) + + svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil) + sub, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{ + UserID: 1001, + GroupID: 1, + ValidityDays: 30, + Notes: "init", + }) + require.NoError(t, err) + require.Equal(t, int64(10), sub.ID) + require.Equal(t, 0, subRepo.createCalls, "reuse should not create new subscription") +} + +func TestAssignSubscriptionConflictWhenSemanticsMismatch(t *testing.T) { + start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC) + groupRepo := &subscriptionGroupRepoStub{ + group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription}, + } + subRepo := newSubscriptionUserSubRepoStub() + subRepo.seed(&UserSubscription{ + ID: 11, + UserID: 2001, + GroupID: 1, + StartsAt: start, + ExpiresAt: start.AddDate(0, 0, 30), + Notes: "old-note", + }) + + svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil) + _, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{ + UserID: 2001, + GroupID: 1, + ValidityDays: 30, + Notes: "new-note", + }) + require.Error(t, err) + require.Equal(t, "SUBSCRIPTION_ASSIGN_CONFLICT", infraerrorsReason(err)) + require.Equal(t, 0, subRepo.createCalls, "conflict should not create or mutate existing subscription") +} + +func TestBulkAssignSubscriptionCreatedReusedAndConflict(t *testing.T) { + start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC) + groupRepo := &subscriptionGroupRepoStub{ + group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription}, + } + subRepo := newSubscriptionUserSubRepoStub() + // user 1: 语义一致,可 reused + subRepo.seed(&UserSubscription{ + ID: 21, + UserID: 1, + GroupID: 1, + StartsAt: start, + ExpiresAt: start.AddDate(0, 0, 30), + Notes: "same-note", + }) + // user 3: 语义冲突(有效期不一致),应 failed + subRepo.seed(&UserSubscription{ + ID: 23, + UserID: 3, + GroupID: 1, + StartsAt: start, + ExpiresAt: start.AddDate(0, 0, 60), + Notes: "same-note", + }) + + svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil) + result, err := svc.BulkAssignSubscription(context.Background(), &BulkAssignSubscriptionInput{ + UserIDs: []int64{1, 2, 3}, + GroupID: 1, + ValidityDays: 30, + AssignedBy: 9, + Notes: "same-note", + }) + require.NoError(t, err) + require.Equal(t, 2, result.SuccessCount) + require.Equal(t, 1, result.CreatedCount) + require.Equal(t, 1, result.ReusedCount) + require.Equal(t, 1, result.FailedCount) + require.Equal(t, "reused", result.Statuses[1]) + require.Equal(t, "created", result.Statuses[2]) + require.Equal(t, "failed", result.Statuses[3]) + require.Equal(t, 1, subRepo.createCalls) +} + +func TestAssignSubscriptionKeepsWorkingWhenIdempotencyStoreUnavailable(t *testing.T) { + groupRepo := &subscriptionGroupRepoStub{ + group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription}, + } + subRepo := newSubscriptionUserSubRepoStub() + SetDefaultIdempotencyCoordinator(NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig())) + t.Cleanup(func() { + SetDefaultIdempotencyCoordinator(nil) + }) + + svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil) + sub, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{ + UserID: 9001, + GroupID: 1, + ValidityDays: 30, + Notes: "new", + }) + require.NoError(t, err) + require.NotNil(t, sub) + require.Equal(t, 1, subRepo.createCalls, "semantic idempotent endpoint should not depend on idempotency store availability") +} + +func TestNormalizeAssignValidityDays(t *testing.T) { + require.Equal(t, 30, normalizeAssignValidityDays(0)) + require.Equal(t, 30, normalizeAssignValidityDays(-5)) + require.Equal(t, MaxValidityDays, normalizeAssignValidityDays(MaxValidityDays+100)) + require.Equal(t, 7, normalizeAssignValidityDays(7)) +} + +func TestDetectAssignSemanticConflictCases(t *testing.T) { + start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC) + base := &UserSubscription{ + UserID: 1, + GroupID: 1, + StartsAt: start, + ExpiresAt: start.AddDate(0, 0, 30), + Notes: "same", + } + + reason, conflict := detectAssignSemanticConflict(base, &AssignSubscriptionInput{ + UserID: 1, + GroupID: 1, + ValidityDays: 30, + Notes: "same", + }) + require.False(t, conflict) + require.Equal(t, "", reason) + + reason, conflict = detectAssignSemanticConflict(base, &AssignSubscriptionInput{ + UserID: 1, + GroupID: 1, + ValidityDays: 60, + Notes: "same", + }) + require.True(t, conflict) + require.Equal(t, "validity_days_mismatch", reason) + + reason, conflict = detectAssignSemanticConflict(base, &AssignSubscriptionInput{ + UserID: 1, + GroupID: 1, + ValidityDays: 30, + Notes: "other", + }) + require.True(t, conflict) + require.Equal(t, "notes_mismatch", reason) +} + +func TestAssignSubscriptionGroupTypeValidation(t *testing.T) { + groupRepo := &subscriptionGroupRepoStub{ + group: &Group{ID: 1, SubscriptionType: SubscriptionTypeStandard}, + } + subRepo := newSubscriptionUserSubRepoStub() + svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil) + + _, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{ + UserID: 1, + GroupID: 1, + ValidityDays: 30, + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrGroupNotSubscriptionType), infraerrors.Code(err)) +} + +func strconvFormatInt(v int64) string { + return strconv.FormatInt(v, 10) +} + +func infraerrorsReason(err error) string { + return infraerrors.Reason(err) +} diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 29ef3662..57e04266 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -6,6 +6,7 @@ import ( "log" "math/rand/v2" "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -24,16 +25,17 @@ var MaxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC) const MaxValidityDays = 36500 var ( - ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found") - ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired") - ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended") - ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group") - ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type") - ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") - ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") - ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") - ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil") - ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)") + ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found") + ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired") + ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended") + ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group") + ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics") + ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type") + ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") + ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") + ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") + ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil") + ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)") ) // SubscriptionService 订阅服务 @@ -150,40 +152,10 @@ type AssignSubscriptionInput struct { // AssignSubscription 分配订阅给用户(不允许重复分配) func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) { - // 检查分组是否存在且为订阅类型 - group, err := s.groupRepo.GetByID(ctx, input.GroupID) - if err != nil { - return nil, fmt.Errorf("group not found: %w", err) - } - if !group.IsSubscriptionType() { - return nil, ErrGroupNotSubscriptionType - } - - // 检查是否已存在订阅 - exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID) + sub, _, err := s.assignSubscriptionWithReuse(ctx, input) if err != nil { return nil, err } - if exists { - return nil, ErrSubscriptionAlreadyExists - } - - sub, err := s.createSubscription(ctx, input) - if err != nil { - return nil, err - } - - // 失效订阅缓存 - s.InvalidateSubCache(input.UserID, input.GroupID) - if s.billingCacheService != nil { - userID, groupID := input.UserID, input.GroupID - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) - }() - } - return sub, nil } @@ -363,9 +335,12 @@ type BulkAssignSubscriptionInput struct { // BulkAssignResult 批量分配结果 type BulkAssignResult struct { SuccessCount int + CreatedCount int + ReusedCount int FailedCount int Subscriptions []UserSubscription Errors []string + Statuses map[int64]string } // BulkAssignSubscription 批量分配订阅 @@ -373,10 +348,11 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input result := &BulkAssignResult{ Subscriptions: make([]UserSubscription, 0), Errors: make([]string, 0), + Statuses: make(map[int64]string), } for _, userID := range input.UserIDs { - sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{ + sub, reused, err := s.assignSubscriptionWithReuse(ctx, &AssignSubscriptionInput{ UserID: userID, GroupID: input.GroupID, ValidityDays: input.ValidityDays, @@ -386,15 +362,105 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input if err != nil { result.FailedCount++ result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err)) + result.Statuses[userID] = "failed" } else { result.SuccessCount++ result.Subscriptions = append(result.Subscriptions, *sub) + if reused { + result.ReusedCount++ + result.Statuses[userID] = "reused" + } else { + result.CreatedCount++ + result.Statuses[userID] = "created" + } } } return result, nil } +func (s *SubscriptionService) assignSubscriptionWithReuse(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { + // 检查分组是否存在且为订阅类型 + group, err := s.groupRepo.GetByID(ctx, input.GroupID) + if err != nil { + return nil, false, fmt.Errorf("group not found: %w", err) + } + if !group.IsSubscriptionType() { + return nil, false, ErrGroupNotSubscriptionType + } + + // 检查是否已存在订阅;若已存在,则按幂等成功返回现有订阅 + exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID) + if err != nil { + return nil, false, err + } + if exists { + sub, getErr := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID) + if getErr != nil { + return nil, false, getErr + } + if conflictReason, conflict := detectAssignSemanticConflict(sub, input); conflict { + return nil, false, ErrSubscriptionAssignConflict.WithMetadata(map[string]string{ + "conflict_reason": conflictReason, + }) + } + return sub, true, nil + } + + sub, err := s.createSubscription(ctx, input) + if err != nil { + return nil, false, err + } + + // 失效订阅缓存 + s.InvalidateSubCache(input.UserID, input.GroupID) + if s.billingCacheService != nil { + userID, groupID := input.UserID, input.GroupID + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + }() + } + + return sub, false, nil +} + +func detectAssignSemanticConflict(existing *UserSubscription, input *AssignSubscriptionInput) (string, bool) { + if existing == nil || input == nil { + return "", false + } + + normalizedDays := normalizeAssignValidityDays(input.ValidityDays) + if !existing.StartsAt.IsZero() { + expectedExpiresAt := existing.StartsAt.AddDate(0, 0, normalizedDays) + if expectedExpiresAt.After(MaxExpiresAt) { + expectedExpiresAt = MaxExpiresAt + } + if !existing.ExpiresAt.Equal(expectedExpiresAt) { + return "validity_days_mismatch", true + } + } + + existingNotes := strings.TrimSpace(existing.Notes) + inputNotes := strings.TrimSpace(input.Notes) + if existingNotes != inputNotes { + return "notes_mismatch", true + } + + return "", false +} + +func normalizeAssignValidityDays(days int) int { + if days <= 0 { + days = 30 + } + if days > MaxValidityDays { + days = MaxValidityDays + } + return days +} + // RevokeSubscription 撤销订阅 func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error { // 先获取订阅信息用于失效缓存 diff --git a/backend/internal/service/system_operation_lock_service.go b/backend/internal/service/system_operation_lock_service.go new file mode 100644 index 00000000..ed5563cd --- /dev/null +++ b/backend/internal/service/system_operation_lock_service.go @@ -0,0 +1,214 @@ +package service + +import ( + "context" + "fmt" + "strconv" + "sync" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +const ( + systemOperationLockScope = "admin.system.operations.global_lock" + systemOperationLockKey = "global-system-operation-lock" +) + +var ( + ErrSystemOperationBusy = infraerrors.Conflict("SYSTEM_OPERATION_BUSY", "another system operation is in progress") +) + +type SystemOperationLock struct { + recordID int64 + operationID string + + stopOnce sync.Once + stopCh chan struct{} +} + +func (l *SystemOperationLock) OperationID() string { + if l == nil { + return "" + } + return l.operationID +} + +type SystemOperationLockService struct { + repo IdempotencyRepository + + lease time.Duration + renewInterval time.Duration + ttl time.Duration +} + +func NewSystemOperationLockService(repo IdempotencyRepository, cfg IdempotencyConfig) *SystemOperationLockService { + lease := cfg.ProcessingTimeout + if lease <= 0 { + lease = 30 * time.Second + } + renewInterval := lease / 3 + if renewInterval < time.Second { + renewInterval = time.Second + } + ttl := cfg.SystemOperationTTL + if ttl <= 0 { + ttl = time.Hour + } + + return &SystemOperationLockService{ + repo: repo, + lease: lease, + renewInterval: renewInterval, + ttl: ttl, + } +} + +func (s *SystemOperationLockService) Acquire(ctx context.Context, operationID string) (*SystemOperationLock, error) { + if s == nil || s.repo == nil { + return nil, ErrIdempotencyStoreUnavail + } + if operationID == "" { + return nil, infraerrors.BadRequest("SYSTEM_OPERATION_ID_REQUIRED", "operation id is required") + } + + now := time.Now() + expiresAt := now.Add(s.ttl) + lockedUntil := now.Add(s.lease) + keyHash := HashIdempotencyKey(systemOperationLockKey) + + record := &IdempotencyRecord{ + Scope: systemOperationLockScope, + IdempotencyKeyHash: keyHash, + RequestFingerprint: operationID, + Status: IdempotencyStatusProcessing, + LockedUntil: &lockedUntil, + ExpiresAt: expiresAt, + } + + owner, err := s.repo.CreateProcessing(ctx, record) + if err != nil { + return nil, ErrIdempotencyStoreUnavail.WithCause(err) + } + if !owner { + existing, getErr := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash) + if getErr != nil { + return nil, ErrIdempotencyStoreUnavail.WithCause(getErr) + } + if existing == nil { + return nil, ErrIdempotencyStoreUnavail + } + if existing.Status == IdempotencyStatusProcessing && existing.LockedUntil != nil && existing.LockedUntil.After(now) { + return nil, s.busyError(existing.RequestFingerprint, existing.LockedUntil, now) + } + reclaimed, reclaimErr := s.repo.TryReclaim( + ctx, + existing.ID, + existing.Status, + now, + lockedUntil, + expiresAt, + ) + if reclaimErr != nil { + return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr) + } + if !reclaimed { + latest, _ := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash) + if latest != nil { + return nil, s.busyError(latest.RequestFingerprint, latest.LockedUntil, now) + } + return nil, ErrSystemOperationBusy + } + record.ID = existing.ID + } + + if record.ID == 0 { + return nil, ErrIdempotencyStoreUnavail + } + + lock := &SystemOperationLock{ + recordID: record.ID, + operationID: operationID, + stopCh: make(chan struct{}), + } + go s.renewLoop(lock) + + return lock, nil +} + +func (s *SystemOperationLockService) Release(ctx context.Context, lock *SystemOperationLock, succeeded bool, failureReason string) error { + if s == nil || s.repo == nil || lock == nil { + return nil + } + + lock.stopOnce.Do(func() { + close(lock.stopCh) + }) + + if ctx == nil { + ctx = context.Background() + } + + expiresAt := time.Now().Add(s.ttl) + if succeeded { + responseBody := fmt.Sprintf(`{"operation_id":"%s","released":true}`, lock.operationID) + return s.repo.MarkSucceeded(ctx, lock.recordID, 200, responseBody, expiresAt) + } + + reason := failureReason + if reason == "" { + reason = "SYSTEM_OPERATION_FAILED" + } + return s.repo.MarkFailedRetryable(ctx, lock.recordID, reason, time.Now(), expiresAt) +} + +func (s *SystemOperationLockService) renewLoop(lock *SystemOperationLock) { + ticker := time.NewTicker(s.renewInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + now := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ok, err := s.repo.ExtendProcessingLock( + ctx, + lock.recordID, + lock.operationID, + now.Add(s.lease), + now.Add(s.ttl), + ) + cancel() + if err != nil { + logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew failed operation_id=%s err=%v", lock.operationID, err) + // 瞬时故障不应导致续租协程退出,下一轮继续尝试续租。 + continue + } + if !ok { + logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew stopped operation_id=%s reason=ownership_lost", lock.operationID) + return + } + case <-lock.stopCh: + return + } + } +} + +func (s *SystemOperationLockService) busyError(operationID string, lockedUntil *time.Time, now time.Time) error { + metadata := make(map[string]string) + if operationID != "" { + metadata["operation_id"] = operationID + } + if lockedUntil != nil { + sec := int(lockedUntil.Sub(now).Seconds()) + if sec <= 0 { + sec = 1 + } + metadata["retry_after"] = strconv.Itoa(sec) + } + if len(metadata) == 0 { + return ErrSystemOperationBusy + } + return ErrSystemOperationBusy.WithMetadata(metadata) +} diff --git a/backend/internal/service/system_operation_lock_service_test.go b/backend/internal/service/system_operation_lock_service_test.go new file mode 100644 index 00000000..cd913ba8 --- /dev/null +++ b/backend/internal/service/system_operation_lock_service_test.go @@ -0,0 +1,305 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestSystemOperationLockService_AcquireBusyAndRelease(t *testing.T) { + repo := newInMemoryIdempotencyRepo() + svc := NewSystemOperationLockService(repo, IdempotencyConfig{ + SystemOperationTTL: 10 * time.Second, + ProcessingTimeout: 2 * time.Second, + }) + + lock1, err := svc.Acquire(context.Background(), "op-1") + require.NoError(t, err) + require.NotNil(t, lock1) + + _, err = svc.Acquire(context.Background(), "op-2") + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err)) + appErr := infraerrors.FromError(err) + require.Equal(t, "op-1", appErr.Metadata["operation_id"]) + require.NotEmpty(t, appErr.Metadata["retry_after"]) + + require.NoError(t, svc.Release(context.Background(), lock1, true, "")) + + lock2, err := svc.Acquire(context.Background(), "op-2") + require.NoError(t, err) + require.NotNil(t, lock2) + require.NoError(t, svc.Release(context.Background(), lock2, true, "")) +} + +func TestSystemOperationLockService_RenewLease(t *testing.T) { + repo := newInMemoryIdempotencyRepo() + svc := NewSystemOperationLockService(repo, IdempotencyConfig{ + SystemOperationTTL: 5 * time.Second, + ProcessingTimeout: 1200 * time.Millisecond, + }) + + lock, err := svc.Acquire(context.Background(), "op-renew") + require.NoError(t, err) + require.NotNil(t, lock) + defer func() { + _ = svc.Release(context.Background(), lock, true, "") + }() + + keyHash := HashIdempotencyKey(systemOperationLockKey) + initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash) + require.NotNil(t, initial) + require.NotNil(t, initial.LockedUntil) + initialLockedUntil := *initial.LockedUntil + + time.Sleep(1500 * time.Millisecond) + + updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash) + require.NotNil(t, updated) + require.NotNil(t, updated.LockedUntil) + require.True(t, updated.LockedUntil.After(initialLockedUntil), "locked_until should be renewed while lock is held") +} + +type flakySystemLockRenewRepo struct { + *inMemoryIdempotencyRepo + extendCalls int32 +} + +func (r *flakySystemLockRenewRepo) ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + call := atomic.AddInt32(&r.extendCalls, 1) + if call == 1 { + return false, errors.New("transient extend failure") + } + return r.inMemoryIdempotencyRepo.ExtendProcessingLock(ctx, id, requestFingerprint, newLockedUntil, newExpiresAt) +} + +func TestSystemOperationLockService_RenewLeaseContinuesAfterTransientFailure(t *testing.T) { + repo := &flakySystemLockRenewRepo{inMemoryIdempotencyRepo: newInMemoryIdempotencyRepo()} + svc := NewSystemOperationLockService(repo, IdempotencyConfig{ + SystemOperationTTL: 5 * time.Second, + ProcessingTimeout: 2400 * time.Millisecond, + }) + + lock, err := svc.Acquire(context.Background(), "op-renew-transient") + require.NoError(t, err) + require.NotNil(t, lock) + defer func() { + _ = svc.Release(context.Background(), lock, true, "") + }() + + keyHash := HashIdempotencyKey(systemOperationLockKey) + initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash) + require.NotNil(t, initial) + require.NotNil(t, initial.LockedUntil) + initialLockedUntil := *initial.LockedUntil + + // 首次续租失败后,下一轮应继续尝试并成功更新锁过期时间。 + require.Eventually(t, func() bool { + updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash) + if updated == nil || updated.LockedUntil == nil { + return false + } + return atomic.LoadInt32(&repo.extendCalls) >= 2 && updated.LockedUntil.After(initialLockedUntil) + }, 4*time.Second, 100*time.Millisecond, "renew loop should continue after transient error") +} + +func TestSystemOperationLockService_SameOperationIDRetryWhileRunning(t *testing.T) { + repo := newInMemoryIdempotencyRepo() + svc := NewSystemOperationLockService(repo, IdempotencyConfig{ + SystemOperationTTL: 10 * time.Second, + ProcessingTimeout: 2 * time.Second, + }) + + lock1, err := svc.Acquire(context.Background(), "op-same") + require.NoError(t, err) + require.NotNil(t, lock1) + + _, err = svc.Acquire(context.Background(), "op-same") + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err)) + appErr := infraerrors.FromError(err) + require.Equal(t, "op-same", appErr.Metadata["operation_id"]) + + require.NoError(t, svc.Release(context.Background(), lock1, true, "")) + + lock2, err := svc.Acquire(context.Background(), "op-same") + require.NoError(t, err) + require.NotNil(t, lock2) + require.NoError(t, svc.Release(context.Background(), lock2, true, "")) +} + +func TestSystemOperationLockService_RecoverAfterLeaseExpired(t *testing.T) { + repo := newInMemoryIdempotencyRepo() + svc := NewSystemOperationLockService(repo, IdempotencyConfig{ + SystemOperationTTL: 5 * time.Second, + ProcessingTimeout: 300 * time.Millisecond, + }) + + lock1, err := svc.Acquire(context.Background(), "op-crashed") + require.NoError(t, err) + require.NotNil(t, lock1) + + // 模拟实例异常:停止续租,不调用 Release。 + lock1.stopOnce.Do(func() { + close(lock1.stopCh) + }) + + time.Sleep(450 * time.Millisecond) + + lock2, err := svc.Acquire(context.Background(), "op-recovered") + require.NoError(t, err, "expired lease should allow a new operation to reclaim lock") + require.NotNil(t, lock2) + require.NoError(t, svc.Release(context.Background(), lock2, true, "")) +} + +type systemLockRepoStub struct { + createOwner bool + createErr error + existing *IdempotencyRecord + getErr error + reclaimOK bool + reclaimErr error + markSuccErr error + markFailErr error +} + +func (s *systemLockRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + if s.createErr != nil { + return false, s.createErr + } + return s.createOwner, nil +} + +func (s *systemLockRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + if s.getErr != nil { + return nil, s.getErr + } + return cloneRecord(s.existing), nil +} + +func (s *systemLockRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + if s.reclaimErr != nil { + return false, s.reclaimErr + } + return s.reclaimOK, nil +} + +func (s *systemLockRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return true, nil +} + +func (s *systemLockRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return s.markSuccErr +} + +func (s *systemLockRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return s.markFailErr +} + +func (s *systemLockRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, nil +} + +func TestSystemOperationLockService_InputAndStoreErrorBranches(t *testing.T) { + var nilSvc *SystemOperationLockService + _, err := nilSvc.Acquire(context.Background(), "x") + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + svc := &SystemOperationLockService{repo: nil} + _, err = svc.Acquire(context.Background(), "x") + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + svc = NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + SystemOperationTTL: 10 * time.Second, + ProcessingTimeout: 2 * time.Second, + }) + _, err = svc.Acquire(context.Background(), "") + require.Error(t, err) + require.Equal(t, "SYSTEM_OPERATION_ID_REQUIRED", infraerrors.Reason(err)) + + badStore := &systemLockRepoStub{createErr: errors.New("db down")} + svc = NewSystemOperationLockService(badStore, IdempotencyConfig{ + SystemOperationTTL: 10 * time.Second, + ProcessingTimeout: 2 * time.Second, + }) + _, err = svc.Acquire(context.Background(), "x") + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) +} + +func TestSystemOperationLockService_ExistingNilAndReclaimBranches(t *testing.T) { + now := time.Now() + repo := &systemLockRepoStub{ + createOwner: false, + } + svc := NewSystemOperationLockService(repo, IdempotencyConfig{ + SystemOperationTTL: 10 * time.Second, + ProcessingTimeout: 2 * time.Second, + }) + + _, err := svc.Acquire(context.Background(), "op") + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.existing = &IdempotencyRecord{ + ID: 1, + Scope: systemOperationLockScope, + IdempotencyKeyHash: HashIdempotencyKey(systemOperationLockKey), + RequestFingerprint: "other-op", + Status: IdempotencyStatusFailedRetryable, + LockedUntil: ptrTime(now.Add(-time.Second)), + ExpiresAt: now.Add(time.Hour), + } + repo.reclaimErr = errors.New("reclaim failed") + _, err = svc.Acquire(context.Background(), "op") + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.reclaimErr = nil + repo.reclaimOK = false + _, err = svc.Acquire(context.Background(), "op") + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err)) +} + +func TestSystemOperationLockService_ReleaseBranchesAndOperationID(t *testing.T) { + require.Equal(t, "", (*SystemOperationLock)(nil).OperationID()) + + svc := NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + SystemOperationTTL: 10 * time.Second, + ProcessingTimeout: 2 * time.Second, + }) + lock, err := svc.Acquire(context.Background(), "op") + require.NoError(t, err) + require.NotNil(t, lock) + + require.NoError(t, svc.Release(context.Background(), lock, false, "")) + require.NoError(t, svc.Release(context.Background(), lock, true, "")) + + repo := &systemLockRepoStub{ + createOwner: true, + markSuccErr: errors.New("mark succeeded failed"), + markFailErr: errors.New("mark failed failed"), + } + svc = NewSystemOperationLockService(repo, IdempotencyConfig{ + SystemOperationTTL: 10 * time.Second, + ProcessingTimeout: 2 * time.Second, + }) + lock = &SystemOperationLock{recordID: 1, operationID: "op2", stopCh: make(chan struct{})} + require.Error(t, svc.Release(context.Background(), lock, true, "")) + lock = &SystemOperationLock{recordID: 1, operationID: "op3", stopCh: make(chan struct{})} + require.Error(t, svc.Release(context.Background(), lock, false, "BAD")) + + var nilLockSvc *SystemOperationLockService + require.NoError(t, nilLockSvc.Release(context.Background(), nil, true, "")) + + err = svc.busyError("", nil, time.Now()) + require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err)) +} diff --git a/backend/internal/service/usage_cleanup_service.go b/backend/internal/service/usage_cleanup_service.go index 9eeaded4..ee795aa4 100644 --- a/backend/internal/service/usage_cleanup_service.go +++ b/backend/internal/service/usage_cleanup_service.go @@ -320,6 +320,10 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc return err } logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status) + if status == UsageCleanupStatusCanceled { + logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent hit: task=%d operator=%d", taskID, canceledBy) + return nil + } if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning { return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status") } @@ -329,6 +333,11 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc } if !ok { // 状态可能并发改变 + currentStatus, getErr := s.repo.GetTaskStatus(ctx, taskID) + if getErr == nil && currentStatus == UsageCleanupStatusCanceled { + logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent race hit: task=%d operator=%d", taskID, canceledBy) + return nil + } return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status") } logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy) diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go index c6c309b6..1f9f4776 100644 --- a/backend/internal/service/usage_cleanup_service_test.go +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -644,6 +644,23 @@ func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) { require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err)) } +func TestUsageCleanupServiceCancelTaskAlreadyCanceledIsIdempotent(t *testing.T) { + repo := &cleanupRepoStub{ + statusByID: map[int64]string{ + 7: UsageCleanupStatusCanceled, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + err := svc.CancelTask(context.Background(), 7, 1) + require.NoError(t, err) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Empty(t, repo.cancelCalls, "already canceled should return success without extra cancel write") +} + func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) { shouldCancel := false repo := &cleanupRepoStub{ diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index bfc2ea48..bd241566 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -225,6 +225,45 @@ func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Confi return svc } +func buildIdempotencyConfig(cfg *config.Config) IdempotencyConfig { + idempotencyCfg := DefaultIdempotencyConfig() + if cfg != nil { + if cfg.Idempotency.DefaultTTLSeconds > 0 { + idempotencyCfg.DefaultTTL = time.Duration(cfg.Idempotency.DefaultTTLSeconds) * time.Second + } + if cfg.Idempotency.SystemOperationTTLSeconds > 0 { + idempotencyCfg.SystemOperationTTL = time.Duration(cfg.Idempotency.SystemOperationTTLSeconds) * time.Second + } + if cfg.Idempotency.ProcessingTimeoutSeconds > 0 { + idempotencyCfg.ProcessingTimeout = time.Duration(cfg.Idempotency.ProcessingTimeoutSeconds) * time.Second + } + if cfg.Idempotency.FailedRetryBackoffSeconds > 0 { + idempotencyCfg.FailedRetryBackoff = time.Duration(cfg.Idempotency.FailedRetryBackoffSeconds) * time.Second + } + if cfg.Idempotency.MaxStoredResponseLen > 0 { + idempotencyCfg.MaxStoredResponseLen = cfg.Idempotency.MaxStoredResponseLen + } + idempotencyCfg.ObserveOnly = cfg.Idempotency.ObserveOnly + } + return idempotencyCfg +} + +func ProvideIdempotencyCoordinator(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCoordinator { + coordinator := NewIdempotencyCoordinator(repo, buildIdempotencyConfig(cfg)) + SetDefaultIdempotencyCoordinator(coordinator) + return coordinator +} + +func ProvideSystemOperationLockService(repo IdempotencyRepository, cfg *config.Config) *SystemOperationLockService { + return NewSystemOperationLockService(repo, buildIdempotencyConfig(cfg)) +} + +func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService { + svc := NewIdempotencyCleanupService(repo, cfg) + svc.Start() + return svc +} + // ProvideOpsScheduledReportService creates and starts OpsScheduledReportService. func ProvideOpsScheduledReportService( opsService *OpsService, @@ -318,4 +357,7 @@ var ProviderSet = wire.NewSet( NewTotpService, NewErrorPassthroughService, NewDigestSessionStore, + ProvideIdempotencyCoordinator, + ProvideSystemOperationLockService, + ProvideIdempotencyCleanupService, ) diff --git a/backend/migrations/057_add_idempotency_records.sql b/backend/migrations/057_add_idempotency_records.sql new file mode 100644 index 00000000..15738bf9 --- /dev/null +++ b/backend/migrations/057_add_idempotency_records.sql @@ -0,0 +1,27 @@ +-- 幂等记录表:用于关键写接口的请求去重与结果重放 +-- 幂等执行:可重复运行 + +CREATE TABLE IF NOT EXISTS idempotency_records ( + id BIGSERIAL PRIMARY KEY, + scope VARCHAR(128) NOT NULL, + idempotency_key_hash VARCHAR(64) NOT NULL, + request_fingerprint VARCHAR(64) NOT NULL, + status VARCHAR(32) NOT NULL, + response_status INTEGER, + response_body TEXT, + error_reason VARCHAR(128), + locked_until TIMESTAMPTZ, + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_idempotency_records_scope_key + ON idempotency_records (scope, idempotency_key_hash); + +CREATE INDEX IF NOT EXISTS idx_idempotency_records_expires_at + ON idempotency_records (expires_at); + +CREATE INDEX IF NOT EXISTS idx_idempotency_records_status_locked_until + ON idempotency_records (status, locked_until); + diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 73bf77c0..d1c058ec 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -568,6 +568,30 @@ usage_cleanup: # 单次任务最大执行时长(秒) task_timeout_seconds: 1800 +# ============================================================================= +# HTTP 写接口幂等配置 +# Idempotency Configuration +# ============================================================================= +idempotency: + # Observe-only 模式: + # true: 观察期,不带 Idempotency-Key 仍放行(但会记录) + # false: 强制期,不带 Idempotency-Key 直接拒绝(仅对接入幂等保护的接口生效) + observe_only: true + # 关键写接口幂等记录 TTL(秒) + default_ttl_seconds: 86400 + # 系统操作接口(update/rollback/restart)幂等记录 TTL(秒) + system_operation_ttl_seconds: 3600 + # processing 锁超时(秒) + processing_timeout_seconds: 30 + # 可重试失败退避窗口(秒) + failed_retry_backoff_seconds: 5 + # 持久化响应体最大长度(字节) + max_stored_response_len: 65536 + # 过期幂等记录清理周期(秒) + cleanup_interval_seconds: 60 + # 每轮清理最大删除条数 + cleanup_batch_size: 500 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置