feat(idempotency): 为关键写接口接入幂等并完善并发容错
This commit is contained in:
@@ -74,6 +74,7 @@ type Config struct {
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
@@ -137,6 +138,25 @@ type UpdateConfig struct {
|
||||
ProxyURL string `mapstructure:"proxy_url"`
|
||||
}
|
||||
|
||||
type IdempotencyConfig struct {
|
||||
// ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。
|
||||
ObserveOnly bool `mapstructure:"observe_only"`
|
||||
// DefaultTTLSeconds 关键写接口的幂等记录默认 TTL(秒)。
|
||||
DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"`
|
||||
// SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL(秒)。
|
||||
SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"`
|
||||
// ProcessingTimeoutSeconds processing 状态锁超时(秒)。
|
||||
ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"`
|
||||
// FailedRetryBackoffSeconds 失败退避窗口(秒)。
|
||||
FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"`
|
||||
// MaxStoredResponseLen 持久化响应体最大长度(字节)。
|
||||
MaxStoredResponseLen int `mapstructure:"max_stored_response_len"`
|
||||
// CleanupIntervalSeconds 过期记录清理周期(秒)。
|
||||
CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
|
||||
// CleanupBatchSize 每次清理的最大记录数。
|
||||
CleanupBatchSize int `mapstructure:"cleanup_batch_size"`
|
||||
}
|
||||
|
||||
type LinuxDoConnectConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
@@ -1117,6 +1137,16 @@ func setDefaults() {
|
||||
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
|
||||
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
|
||||
|
||||
// Idempotency
|
||||
viper.SetDefault("idempotency.observe_only", true)
|
||||
viper.SetDefault("idempotency.default_ttl_seconds", 86400)
|
||||
viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600)
|
||||
viper.SetDefault("idempotency.processing_timeout_seconds", 30)
|
||||
viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5)
|
||||
viper.SetDefault("idempotency.max_stored_response_len", 64*1024)
|
||||
viper.SetDefault("idempotency.cleanup_interval_seconds", 60)
|
||||
viper.SetDefault("idempotency.cleanup_batch_size", 500)
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
viper.SetDefault("gateway.log_upstream_error_body", true)
|
||||
@@ -1560,6 +1590,27 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
|
||||
}
|
||||
}
|
||||
if c.Idempotency.DefaultTTLSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.default_ttl_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.SystemOperationTTLSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.ProcessingTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.processing_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.FailedRetryBackoffSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.MaxStoredResponseLen <= 0 {
|
||||
return fmt.Errorf("idempotency.max_stored_response_len must be positive")
|
||||
}
|
||||
if c.Idempotency.CleanupIntervalSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.CleanupBatchSize <= 0 {
|
||||
return fmt.Errorf("idempotency.cleanup_batch_size must be positive")
|
||||
}
|
||||
if c.Gateway.MaxBodySize <= 0 {
|
||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||
}
|
||||
|
||||
@@ -75,6 +75,42 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Idempotency.ObserveOnly {
|
||||
t.Fatalf("Idempotency.ObserveOnly = false, want true")
|
||||
}
|
||||
if cfg.Idempotency.DefaultTTLSeconds != 86400 {
|
||||
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", cfg.Idempotency.DefaultTTLSeconds)
|
||||
}
|
||||
if cfg.Idempotency.SystemOperationTTLSeconds != 3600 {
|
||||
t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", cfg.Idempotency.SystemOperationTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIdempotencyConfigFromEnv(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false")
|
||||
t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
if cfg.Idempotency.ObserveOnly {
|
||||
t.Fatalf("Idempotency.ObserveOnly = true, want false")
|
||||
}
|
||||
if cfg.Idempotency.DefaultTTLSeconds != 600 {
|
||||
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", cfg.Idempotency.DefaultTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
|
||||
|
||||
@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
dataPayload := req.Data
|
||||
if err := validateDataHeader(dataPayload); err != nil {
|
||||
if err := validateDataHeader(req.Data); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
return h.importData(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
|
||||
skipDefaultGroupBind := true
|
||||
if req.SkipDefaultGroupBind != nil {
|
||||
skipDefaultGroupBind = *req.SkipDefaultGroupBind
|
||||
}
|
||||
|
||||
dataPayload := req.Data
|
||||
result := DataImportResult{}
|
||||
existingProxies, err := h.listAllProxies(c.Request.Context())
|
||||
|
||||
existingProxies, err := h.listAllProxies(ctx)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return result, err
|
||||
}
|
||||
|
||||
proxyKeyToID := make(map[string]int64, len(existingProxies))
|
||||
@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
proxyKeyToID[key] = existingID
|
||||
result.ProxyReused++
|
||||
if normalizedStatus != "" {
|
||||
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
|
||||
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
|
||||
if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
|
||||
_, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
|
||||
Status: normalizedStatus,
|
||||
})
|
||||
}
|
||||
@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
|
||||
Name: defaultProxyName(item.Name),
|
||||
Protocol: item.Protocol,
|
||||
Host: item.Host,
|
||||
@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
Username: item.Username,
|
||||
Password: item.Password,
|
||||
})
|
||||
if err != nil {
|
||||
if createErr != nil {
|
||||
result.ProxyFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "proxy",
|
||||
Name: item.Name,
|
||||
ProxyKey: key,
|
||||
Message: err.Error(),
|
||||
Message: createErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
result.ProxyCreated++
|
||||
|
||||
if normalizedStatus != "" && normalizedStatus != created.Status {
|
||||
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
|
||||
_, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
|
||||
Status: normalizedStatus,
|
||||
})
|
||||
}
|
||||
@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
SkipDefaultGroupBind: skipDefaultGroupBind,
|
||||
}
|
||||
|
||||
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
|
||||
if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
|
||||
result.AccountFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "account",
|
||||
@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
result.AccountCreated++
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
|
||||
|
||||
@@ -405,21 +405,27 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
Notes: req.Notes,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
Notes: req.Notes,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return h.buildAccountResponseWithRuntime(ctx, account), nil
|
||||
})
|
||||
if err != nil {
|
||||
// 检查是否为混合渠道错误
|
||||
@@ -440,11 +446,17 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
|
||||
// Update handles updating an account
|
||||
@@ -838,61 +850,62 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := make([]gin.H, 0, len(req.Accounts))
|
||||
executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
success := 0
|
||||
failed := 0
|
||||
results := make([]gin.H, 0, len(req.Accounts))
|
||||
|
||||
for _, item := range req.Accounts {
|
||||
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
|
||||
failed++
|
||||
for _, item := range req.Accounts {
|
||||
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": "rate_multiplier must be >= 0",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
Platform: item.Platform,
|
||||
Type: item.Type,
|
||||
Credentials: item.Credentials,
|
||||
Extra: item.Extra,
|
||||
ProxyID: item.ProxyID,
|
||||
Concurrency: item.Concurrency,
|
||||
Priority: item.Priority,
|
||||
RateMultiplier: item.RateMultiplier,
|
||||
GroupIDs: item.GroupIDs,
|
||||
ExpiresAt: item.ExpiresAt,
|
||||
AutoPauseOnExpired: item.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": "rate_multiplier must be >= 0",
|
||||
"id": account.ID,
|
||||
"success": true,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
Platform: item.Platform,
|
||||
Type: item.Type,
|
||||
Credentials: item.Credentials,
|
||||
Extra: item.Extra,
|
||||
ProxyID: item.ProxyID,
|
||||
Concurrency: item.Concurrency,
|
||||
Priority: item.Priority,
|
||||
RateMultiplier: item.RateMultiplier,
|
||||
GroupIDs: item.GroupIDs,
|
||||
ExpiresAt: item.ExpiresAt,
|
||||
AutoPauseOnExpired: item.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"id": account.ID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
return gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
115
backend/internal/handler/admin/idempotency_helper.go
Normal file
115
backend/internal/handler/admin/idempotency_helper.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type idempotencyStoreUnavailableMode int
|
||||
|
||||
const (
|
||||
idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
|
||||
idempotencyStoreUnavailableFailOpen
|
||||
)
|
||||
|
||||
func executeAdminIdempotent(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) (*service.IdempotencyExecuteResult, error) {
|
||||
coordinator := service.DefaultIdempotencyCoordinator()
|
||||
if coordinator == nil {
|
||||
data, err := execute(c.Request.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &service.IdempotencyExecuteResult{Data: data}, nil
|
||||
}
|
||||
|
||||
actorScope := "admin:0"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
|
||||
return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
|
||||
Scope: scope,
|
||||
ActorScope: actorScope,
|
||||
Method: c.Request.Method,
|
||||
Route: c.FullPath(),
|
||||
IdempotencyKey: c.GetHeader("Idempotency-Key"),
|
||||
Payload: payload,
|
||||
RequireKey: true,
|
||||
TTL: ttl,
|
||||
}, execute)
|
||||
}
|
||||
|
||||
func executeAdminIdempotentJSON(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
|
||||
}
|
||||
|
||||
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
|
||||
}
|
||||
|
||||
func executeAdminIdempotentJSONWithMode(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
mode idempotencyStoreUnavailableMode,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
|
||||
if err != nil {
|
||||
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
|
||||
strategy := "fail_close"
|
||||
if mode == idempotencyStoreUnavailableFailOpen {
|
||||
strategy = "fail_open"
|
||||
}
|
||||
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
|
||||
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
|
||||
if mode == idempotencyStoreUnavailableFailOpen {
|
||||
data, fallbackErr := execute(c.Request.Context())
|
||||
if fallbackErr != nil {
|
||||
response.ErrorFrom(c, fallbackErr)
|
||||
return
|
||||
}
|
||||
c.Header("X-Idempotency-Degraded", "store-unavailable")
|
||||
response.Success(c, data)
|
||||
return
|
||||
}
|
||||
}
|
||||
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
285
backend/internal/handler/admin/idempotency_helper_test.go
Normal file
285
backend/internal/handler/admin/idempotency_helper_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type storeUnavailableRepoStub struct{}
|
||||
|
||||
func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
|
||||
return nil, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, errors.New("store unavailable")
|
||||
}
|
||||
|
||||
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "test-key-1")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable")
|
||||
}
|
||||
|
||||
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "test-key-2")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
|
||||
require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue")
|
||||
}
|
||||
|
||||
type memoryIdempotencyRepoStub struct {
|
||||
mu sync.Mutex
|
||||
nextID int64
|
||||
data map[string]*service.IdempotencyRecord
|
||||
}
|
||||
|
||||
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
|
||||
return &memoryIdempotencyRepoStub{
|
||||
nextID: 1,
|
||||
data: make(map[string]*service.IdempotencyRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
|
||||
return scope + "|" + keyHash
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := *in
|
||||
if in.LockedUntil != nil {
|
||||
v := *in.LockedUntil
|
||||
out.LockedUntil = &v
|
||||
}
|
||||
if in.ResponseBody != nil {
|
||||
v := *in.ResponseBody
|
||||
out.ResponseBody = &v
|
||||
}
|
||||
if in.ResponseStatus != nil {
|
||||
v := *in.ResponseStatus
|
||||
out.ResponseStatus = &v
|
||||
}
|
||||
if in.ErrorReason != nil {
|
||||
v := *in.ErrorReason
|
||||
out.ErrorReason = &v
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
k := r.key(record.Scope, record.IdempotencyKeyHash)
|
||||
if _, ok := r.data[k]; ok {
|
||||
return false, nil
|
||||
}
|
||||
cp := r.clone(record)
|
||||
cp.ID = r.nextID
|
||||
r.nextID++
|
||||
r.data[k] = cp
|
||||
record.ID = cp.ID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.clone(r.data[r.key(scope, keyHash)]), nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != fromStatus {
|
||||
return false, nil
|
||||
}
|
||||
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
|
||||
return false, nil
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusProcessing
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.ErrorReason = nil
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
|
||||
return false, nil
|
||||
}
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusSucceeded
|
||||
rec.LockedUntil = nil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ResponseStatus = &responseStatus
|
||||
rec.ResponseBody = &responseBody
|
||||
rec.ErrorReason = nil
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusFailedRetryable
|
||||
rec.LockedUntil = &lockedUntil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ErrorReason = &errorReason
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newMemoryIdempotencyRepoStub()
|
||||
cfg := service.DefaultIdempotencyConfig()
|
||||
cfg.ProcessingTimeout = 2 * time.Second
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed atomic.Int32
|
||||
router := gin.New()
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed.Add(1)
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
call := func() (int, http.Header) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "same-key")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
return rec.Code, rec.Header()
|
||||
}
|
||||
|
||||
var status1, status2 int
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
status1, _ = call()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
status2, _ = call()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
|
||||
require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once")
|
||||
|
||||
status3, headers3 := call()
|
||||
require.Equal(t, http.StatusOK, status3)
|
||||
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
|
||||
require.Equal(t, int32(1), executed.Load())
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Protocol: strings.TrimSpace(req.Protocol),
|
||||
Host: strings.TrimSpace(req.Host),
|
||||
Port: req.Port,
|
||||
Username: strings.TrimSpace(req.Username),
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Protocol: strings.TrimSpace(req.Protocol),
|
||||
Host: strings.TrimSpace(req.Host),
|
||||
Port: req.Port,
|
||||
Username: strings.TrimSpace(req.Username),
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dto.ProxyFromService(proxy), nil
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ProxyFromService(proxy))
|
||||
}
|
||||
|
||||
// Update handles updating a proxy
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strconv"
|
||||
@@ -88,23 +89,24 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Delete handles deleting a redeem code
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
idempotencyPayload := struct {
|
||||
SubscriptionID int64 `json:"subscription_id"`
|
||||
Body AdjustSubscriptionRequest `json:"body"`
|
||||
}{
|
||||
SubscriptionID: subscriptionID,
|
||||
Body: req,
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
|
||||
})
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,12 +18,14 @@ import (
|
||||
// SystemHandler handles system-related operations
|
||||
type SystemHandler struct {
|
||||
updateSvc *service.UpdateService
|
||||
lockSvc *service.SystemOperationLockService
|
||||
}
|
||||
|
||||
// NewSystemHandler creates a new SystemHandler
|
||||
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
|
||||
func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
|
||||
return &SystemHandler{
|
||||
updateSvc: updateSvc,
|
||||
lockSvc: lockSvc,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) {
|
||||
// PerformUpdate downloads and applies the update
|
||||
// POST /api/v1/admin/system/update
|
||||
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
|
||||
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Update completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
operationID := buildSystemOperationID(c, "update")
|
||||
payload := gin.H{"operation_id": operationID}
|
||||
executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
lock, release, err := h.acquireSystemLock(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var releaseReason string
|
||||
succeeded := false
|
||||
defer func() {
|
||||
release(releaseReason, succeeded)
|
||||
}()
|
||||
|
||||
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
|
||||
releaseReason = "SYSTEM_UPDATE_FAILED"
|
||||
return nil, err
|
||||
}
|
||||
succeeded = true
|
||||
|
||||
return gin.H{
|
||||
"message": "Update completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
"operation_id": lock.OperationID(),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Rollback restores the previous version
|
||||
// POST /api/v1/admin/system/rollback
|
||||
func (h *SystemHandler) Rollback(c *gin.Context) {
|
||||
if err := h.updateSvc.Rollback(); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Rollback completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
operationID := buildSystemOperationID(c, "rollback")
|
||||
payload := gin.H{"operation_id": operationID}
|
||||
executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
lock, release, err := h.acquireSystemLock(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var releaseReason string
|
||||
succeeded := false
|
||||
defer func() {
|
||||
release(releaseReason, succeeded)
|
||||
}()
|
||||
|
||||
if err := h.updateSvc.Rollback(); err != nil {
|
||||
releaseReason = "SYSTEM_ROLLBACK_FAILED"
|
||||
return nil, err
|
||||
}
|
||||
succeeded = true
|
||||
|
||||
return gin.H{
|
||||
"message": "Rollback completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
"operation_id": lock.OperationID(),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// RestartService restarts the systemd service
|
||||
// POST /api/v1/admin/system/restart
|
||||
func (h *SystemHandler) RestartService(c *gin.Context) {
|
||||
// Schedule service restart in background after sending response
|
||||
// This ensures the client receives the success response before the service restarts
|
||||
go func() {
|
||||
// Wait a moment to ensure the response is sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
sysutil.RestartServiceAsync()
|
||||
}()
|
||||
operationID := buildSystemOperationID(c, "restart")
|
||||
payload := gin.H{"operation_id": operationID}
|
||||
executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
lock, release, err := h.acquireSystemLock(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
succeeded := false
|
||||
defer func() {
|
||||
release("", succeeded)
|
||||
}()
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "Service restart initiated",
|
||||
// Schedule service restart in background after sending response
|
||||
// This ensures the client receives the success response before the service restarts
|
||||
go func() {
|
||||
// Wait a moment to ensure the response is sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
sysutil.RestartServiceAsync()
|
||||
}()
|
||||
succeeded = true
|
||||
return gin.H{
|
||||
"message": "Service restart initiated",
|
||||
"operation_id": lock.OperationID(),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SystemHandler) acquireSystemLock(
|
||||
ctx context.Context,
|
||||
operationID string,
|
||||
) (*service.SystemOperationLock, func(string, bool), error) {
|
||||
if h.lockSvc == nil {
|
||||
return nil, nil, service.ErrIdempotencyStoreUnavail
|
||||
}
|
||||
lock, err := h.lockSvc.Acquire(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
release := func(reason string, succeeded bool) {
|
||||
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
|
||||
}
|
||||
return lock, release, nil
|
||||
}
|
||||
|
||||
func buildSystemOperationID(c *gin.Context, operation string) string {
|
||||
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
|
||||
if key == "" {
|
||||
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
|
||||
}
|
||||
actorScope := "admin:0"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
|
||||
hash := service.HashIdempotencyKey(seed)
|
||||
if len(hash) > 24 {
|
||||
hash = hash[:24]
|
||||
}
|
||||
return "sysop-" + hash
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -472,29 +473,36 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
billingType = *filters.BillingType
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
userID,
|
||||
apiKeyID,
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
stream,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
idempotencyPayload := struct {
|
||||
OperatorID int64 `json:"operator_id"`
|
||||
Body CreateUsageCleanupTaskRequest `json:"body"`
|
||||
}{
|
||||
OperatorID: subject.UserID,
|
||||
Body: req,
|
||||
}
|
||||
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
userID,
|
||||
apiKeyID,
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
stream,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
response.Success(c, dto.UsageCleanupTaskFromService(task))
|
||||
task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
return nil, err
|
||||
}
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
return dto.UsageCleanupTaskFromService(task), nil
|
||||
})
|
||||
}
|
||||
|
||||
// CancelCleanupTask handles canceling a usage cleanup task
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -257,13 +258,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
idempotencyPayload := struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Body UpdateBalanceRequest `json:"body"`
|
||||
}{
|
||||
UserID: userID,
|
||||
Body: req,
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes)
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return dto.UserFromServiceAdmin(user), nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
if req.Quota != nil {
|
||||
svcReq.Quota = *req.Quota
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dto.APIKeyFromService(key), nil
|
||||
})
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -542,11 +543,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
|
||||
}
|
||||
statuses := make(map[string]string, len(r.Statuses))
|
||||
for userID, status := range r.Statuses {
|
||||
statuses[strconv.FormatInt(userID, 10)] = status
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
CreatedCount: r.CreatedCount,
|
||||
ReusedCount: r.ReusedCount,
|
||||
FailedCount: r.FailedCount,
|
||||
Subscriptions: subs,
|
||||
Errors: r.Errors,
|
||||
Statuses: statuses,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -395,9 +395,12 @@ type AdminUserSubscription struct {
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
CreatedCount int `json:"created_count"`
|
||||
ReusedCount int `json:"reused_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []AdminUserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
Statuses map[string]string `json:"statuses,omitempty"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
|
||||
65
backend/internal/handler/idempotency_helper.go
Normal file
65
backend/internal/handler/idempotency_helper.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func executeUserIdempotentJSON(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
coordinator := service.DefaultIdempotencyCoordinator()
|
||||
if coordinator == nil {
|
||||
data, err := execute(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, data)
|
||||
return
|
||||
}
|
||||
|
||||
actorScope := "user:0"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
actorScope = "user:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
|
||||
result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
|
||||
Scope: scope,
|
||||
ActorScope: actorScope,
|
||||
Method: c.Request.Method,
|
||||
Route: c.FullPath(),
|
||||
IdempotencyKey: c.GetHeader("Idempotency-Key"),
|
||||
Payload: payload,
|
||||
RequireKey: true,
|
||||
TTL: ttl,
|
||||
}, execute)
|
||||
if err != nil {
|
||||
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
|
||||
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close")
|
||||
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope)
|
||||
}
|
||||
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
285
backend/internal/handler/idempotency_helper_test.go
Normal file
285
backend/internal/handler/idempotency_helper_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userStoreUnavailableRepoStub struct{}
|
||||
|
||||
func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
|
||||
return nil, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, errors.New("store unavailable")
|
||||
}
|
||||
|
||||
type userMemoryIdempotencyRepoStub struct {
|
||||
mu sync.Mutex
|
||||
nextID int64
|
||||
data map[string]*service.IdempotencyRecord
|
||||
}
|
||||
|
||||
func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub {
|
||||
return &userMemoryIdempotencyRepoStub{
|
||||
nextID: 1,
|
||||
data: make(map[string]*service.IdempotencyRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string {
|
||||
return scope + "|" + keyHash
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := *in
|
||||
if in.LockedUntil != nil {
|
||||
v := *in.LockedUntil
|
||||
out.LockedUntil = &v
|
||||
}
|
||||
if in.ResponseBody != nil {
|
||||
v := *in.ResponseBody
|
||||
out.ResponseBody = &v
|
||||
}
|
||||
if in.ResponseStatus != nil {
|
||||
v := *in.ResponseStatus
|
||||
out.ResponseStatus = &v
|
||||
}
|
||||
if in.ErrorReason != nil {
|
||||
v := *in.ErrorReason
|
||||
out.ErrorReason = &v
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
k := r.key(record.Scope, record.IdempotencyKeyHash)
|
||||
if _, ok := r.data[k]; ok {
|
||||
return false, nil
|
||||
}
|
||||
cp := r.clone(record)
|
||||
cp.ID = r.nextID
|
||||
r.nextID++
|
||||
r.data[k] = cp
|
||||
record.ID = cp.ID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.clone(r.data[r.key(scope, keyHash)]), nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != fromStatus {
|
||||
return false, nil
|
||||
}
|
||||
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
|
||||
return false, nil
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusProcessing
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.ErrorReason = nil
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
|
||||
return false, nil
|
||||
}
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusSucceeded
|
||||
rec.LockedUntil = nil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ResponseStatus = &responseStatus
|
||||
rec.ResponseBody = &responseBody
|
||||
rec.ErrorReason = nil
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusFailedRetryable
|
||||
rec.LockedUntil = &lockedUntil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ErrorReason = &errorReason
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func withUserSubject(userID int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.Use(withUserSubject(1))
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 1, executed)
|
||||
}
|
||||
|
||||
func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.Use(withUserSubject(2))
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "k1")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
require.Equal(t, 0, executed)
|
||||
}
|
||||
|
||||
func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newUserMemoryIdempotencyRepoStub()
|
||||
cfg := service.DefaultIdempotencyConfig()
|
||||
cfg.ProcessingTimeout = 2 * time.Second
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed atomic.Int32
|
||||
router := gin.New()
|
||||
router.Use(withUserSubject(3))
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed.Add(1)
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
call := func() (int, http.Header) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "same-user-key")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
return rec.Code, rec.Header()
|
||||
}
|
||||
|
||||
var status1, status2 int
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() { defer wg.Done(); status1, _ = call() }()
|
||||
go func() { defer wg.Done(); status2, _ = call() }()
|
||||
wg.Wait()
|
||||
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
|
||||
require.Equal(t, int32(1), executed.Load())
|
||||
|
||||
status3, headers3 := call()
|
||||
require.Equal(t, http.StatusOK, status3)
|
||||
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
|
||||
require.Equal(t, int32(1), executed.Load())
|
||||
}
|
||||
@@ -53,8 +53,8 @@ func ProvideAdminHandlers(
|
||||
}
|
||||
|
||||
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
|
||||
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService)
|
||||
func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService, lockService)
|
||||
}
|
||||
|
||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||
@@ -77,6 +77,8 @@ func ProvideHandlers(
|
||||
soraGatewayHandler *SoraGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
_ *service.IdempotencyCoordinator,
|
||||
_ *service.IdempotencyCleanupService,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
|
||||
237
backend/internal/repository/idempotency_repo.go
Normal file
237
backend/internal/repository/idempotency_repo.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type idempotencyRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository {
|
||||
return &idempotencyRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) {
|
||||
if record == nil {
|
||||
return false, nil
|
||||
}
|
||||
query := `
|
||||
INSERT INTO idempotency_records (
|
||||
scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (scope, idempotency_key_hash) DO NOTHING
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
var createdAt time.Time
|
||||
var updatedAt time.Time
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{
|
||||
record.Scope,
|
||||
record.IdempotencyKeyHash,
|
||||
record.RequestFingerprint,
|
||||
record.Status,
|
||||
record.LockedUntil,
|
||||
record.ExpiresAt,
|
||||
}, &record.ID, &createdAt, &updatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
record.CreatedAt = createdAt
|
||||
record.UpdatedAt = updatedAt
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
|
||||
query := `
|
||||
SELECT
|
||||
id, scope, idempotency_key_hash, request_fingerprint, status, response_status,
|
||||
response_body, error_reason, locked_until, expires_at, created_at, updated_at
|
||||
FROM idempotency_records
|
||||
WHERE scope = $1 AND idempotency_key_hash = $2
|
||||
`
|
||||
record := &service.IdempotencyRecord{}
|
||||
var responseStatus sql.NullInt64
|
||||
var responseBody sql.NullString
|
||||
var errorReason sql.NullString
|
||||
var lockedUntil sql.NullTime
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash},
|
||||
&record.ID,
|
||||
&record.Scope,
|
||||
&record.IdempotencyKeyHash,
|
||||
&record.RequestFingerprint,
|
||||
&record.Status,
|
||||
&responseStatus,
|
||||
&responseBody,
|
||||
&errorReason,
|
||||
&lockedUntil,
|
||||
&record.ExpiresAt,
|
||||
&record.CreatedAt,
|
||||
&record.UpdatedAt,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if responseStatus.Valid {
|
||||
v := int(responseStatus.Int64)
|
||||
record.ResponseStatus = &v
|
||||
}
|
||||
if responseBody.Valid {
|
||||
v := responseBody.String
|
||||
record.ResponseBody = &v
|
||||
}
|
||||
if errorReason.Valid {
|
||||
v := errorReason.String
|
||||
record.ErrorReason = &v
|
||||
}
|
||||
if lockedUntil.Valid {
|
||||
v := lockedUntil.Time
|
||||
record.LockedUntil = &v
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) TryReclaim(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
fromStatus string,
|
||||
now, newLockedUntil, newExpiresAt time.Time,
|
||||
) (bool, error) {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET status = $2,
|
||||
locked_until = $3,
|
||||
error_reason = NULL,
|
||||
updated_at = NOW(),
|
||||
expires_at = $4
|
||||
WHERE id = $1
|
||||
AND status = $5
|
||||
AND (locked_until IS NULL OR locked_until <= $6)
|
||||
`
|
||||
res, err := r.sql.ExecContext(ctx, query,
|
||||
id,
|
||||
service.IdempotencyStatusProcessing,
|
||||
newLockedUntil,
|
||||
newExpiresAt,
|
||||
fromStatus,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) ExtendProcessingLock(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
requestFingerprint string,
|
||||
newLockedUntil,
|
||||
newExpiresAt time.Time,
|
||||
) (bool, error) {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET locked_until = $2,
|
||||
expires_at = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND status = $4
|
||||
AND request_fingerprint = $5
|
||||
`
|
||||
res, err := r.sql.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
id,
|
||||
newLockedUntil,
|
||||
newExpiresAt,
|
||||
service.IdempotencyStatusProcessing,
|
||||
requestFingerprint,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET status = $2,
|
||||
response_status = $3,
|
||||
response_body = $4,
|
||||
error_reason = NULL,
|
||||
locked_until = NULL,
|
||||
expires_at = $5,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query,
|
||||
id,
|
||||
service.IdempotencyStatusSucceeded,
|
||||
responseStatus,
|
||||
responseBody,
|
||||
expiresAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET status = $2,
|
||||
error_reason = $3,
|
||||
locked_until = $4,
|
||||
expires_at = $5,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query,
|
||||
id,
|
||||
service.IdempotencyStatusFailedRetryable,
|
||||
errorReason,
|
||||
lockedUntil,
|
||||
expiresAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) {
|
||||
if limit <= 0 {
|
||||
limit = 500
|
||||
}
|
||||
query := `
|
||||
WITH victims AS (
|
||||
SELECT id
|
||||
FROM idempotency_records
|
||||
WHERE expires_at <= $1
|
||||
ORDER BY expires_at ASC
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM idempotency_records
|
||||
WHERE id IN (SELECT id FROM victims)
|
||||
`
|
||||
res, err := r.sql.ExecContext(ctx, query, now, limit)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
144
backend/internal/repository/idempotency_repo_integration_test.go
Normal file
144
backend/internal/repository/idempotency_repo_integration_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
repo := &idempotencyRepository{sql: tx}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := &service.IdempotencyRecord{
|
||||
Scope: uniqueTestValue(t, "idem-scope-create"),
|
||||
IdempotencyKeyHash: uniqueTestValue(t, "idem-hash"),
|
||||
RequestFingerprint: uniqueTestValue(t, "idem-fp"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(30 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err := repo.CreateProcessing(ctx, record)
|
||||
require.NoError(t, err)
|
||||
require.True(t, owner)
|
||||
require.NotZero(t, record.ID)
|
||||
|
||||
duplicate := &service.IdempotencyRecord{
|
||||
Scope: record.Scope,
|
||||
IdempotencyKeyHash: record.IdempotencyKeyHash,
|
||||
RequestFingerprint: uniqueTestValue(t, "idem-fp-other"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(30 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err = repo.CreateProcessing(ctx, duplicate)
|
||||
require.NoError(t, err)
|
||||
require.False(t, owner, "same scope+key hash should be de-duplicated")
|
||||
}
|
||||
|
||||
func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
repo := &idempotencyRepository{sql: tx}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := &service.IdempotencyRecord{
|
||||
Scope: uniqueTestValue(t, "idem-scope-reclaim"),
|
||||
IdempotencyKeyHash: uniqueTestValue(t, "idem-hash-reclaim"),
|
||||
RequestFingerprint: uniqueTestValue(t, "idem-fp-reclaim"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(10 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err := repo.CreateProcessing(ctx, record)
|
||||
require.NoError(t, err)
|
||||
require.True(t, owner)
|
||||
|
||||
require.NoError(t, repo.MarkFailedRetryable(
|
||||
ctx,
|
||||
record.ID,
|
||||
"RETRYABLE_FAILURE",
|
||||
now.Add(-2*time.Second),
|
||||
now.Add(24*time.Hour),
|
||||
))
|
||||
|
||||
newLockedUntil := now.Add(20 * time.Second)
|
||||
reclaimed, err := repo.TryReclaim(
|
||||
ctx,
|
||||
record.ID,
|
||||
service.IdempotencyStatusFailedRetryable,
|
||||
now,
|
||||
newLockedUntil,
|
||||
now.Add(24*time.Hour),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim")
|
||||
|
||||
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, service.IdempotencyStatusProcessing, got.Status)
|
||||
require.NotNil(t, got.LockedUntil)
|
||||
require.True(t, got.LockedUntil.After(now))
|
||||
|
||||
require.NoError(t, repo.MarkFailedRetryable(
|
||||
ctx,
|
||||
record.ID,
|
||||
"RETRYABLE_FAILURE",
|
||||
now.Add(20*time.Second),
|
||||
now.Add(24*time.Hour),
|
||||
))
|
||||
|
||||
reclaimed, err = repo.TryReclaim(
|
||||
ctx,
|
||||
record.ID,
|
||||
service.IdempotencyStatusFailedRetryable,
|
||||
now,
|
||||
now.Add(40*time.Second),
|
||||
now.Add(24*time.Hour),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.False(t, reclaimed, "within lock window should not reclaim")
|
||||
}
|
||||
|
||||
func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
repo := &idempotencyRepository{sql: tx}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := &service.IdempotencyRecord{
|
||||
Scope: uniqueTestValue(t, "idem-scope-success"),
|
||||
IdempotencyKeyHash: uniqueTestValue(t, "idem-hash-success"),
|
||||
RequestFingerprint: uniqueTestValue(t, "idem-fp-success"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(10 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err := repo.CreateProcessing(ctx, record)
|
||||
require.NoError(t, err)
|
||||
require.True(t, owner)
|
||||
|
||||
require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour)))
|
||||
|
||||
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, service.IdempotencyStatusSucceeded, got.Status)
|
||||
require.NotNil(t, got.ResponseStatus)
|
||||
require.Equal(t, 200, *got.ResponseStatus)
|
||||
require.NotNil(t, got.ResponseBody)
|
||||
require.Equal(t, `{"ok":true}`, *got.ResponseBody)
|
||||
require.Nil(t, got.LockedUntil)
|
||||
}
|
||||
|
||||
func ptrTime(v time.Time) *time.Time {
|
||||
return &v
|
||||
}
|
||||
@@ -60,6 +60,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewIdempotencyRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
NewSettingRepository,
|
||||
|
||||
@@ -35,6 +35,8 @@ var (
|
||||
const (
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
apiKeyLastUsedMinTouch = 30 * time.Second
|
||||
// DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。
|
||||
apiKeyLastUsedFailBackoff = 5 * time.Second
|
||||
)
|
||||
|
||||
type APIKeyRepository interface {
|
||||
@@ -129,7 +131,7 @@ type APIKeyService struct {
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
lastUsedTouchL1 sync.Map // keyID -> time.Time
|
||||
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
|
||||
lastUsedTouchSF singleflight.Group
|
||||
}
|
||||
|
||||
@@ -574,7 +576,7 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
|
||||
|
||||
now := time.Now()
|
||||
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
|
||||
if last, ok := v.(time.Time); ok && now.Sub(last) < apiKeyLastUsedMinTouch {
|
||||
if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -582,15 +584,16 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
|
||||
_, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) {
|
||||
latest := time.Now()
|
||||
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
|
||||
if last, ok := v.(time.Time); ok && latest.Sub(last) < apiKeyLastUsedMinTouch {
|
||||
if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil {
|
||||
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff))
|
||||
return nil, fmt.Errorf("touch api key last used: %w", err)
|
||||
}
|
||||
s.lastUsedTouchL1.Store(keyID, latest)
|
||||
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch))
|
||||
return nil, nil
|
||||
})
|
||||
return err
|
||||
|
||||
@@ -79,8 +79,27 @@ func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) {
|
||||
require.ErrorContains(t, err, "touch api key last used")
|
||||
require.Equal(t, []int64{123}, repo.touchedIDs)
|
||||
|
||||
_, ok := svc.lastUsedTouchL1.Load(int64(123))
|
||||
require.False(t, ok, "failed touch should not update debounce cache")
|
||||
cached, ok := svc.lastUsedTouchL1.Load(int64(123))
|
||||
require.True(t, ok, "failed touch should still update retry debounce cache")
|
||||
_, isTime := cached.(time.Time)
|
||||
require.True(t, isTime)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
return errors.New("db write failed")
|
||||
},
|
||||
}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
firstErr := svc.TouchLastUsed(context.Background(), 456)
|
||||
require.Error(t, firstErr)
|
||||
require.ErrorContains(t, firstErr, "touch api key last used")
|
||||
|
||||
secondErr := svc.TouchLastUsed(context.Background(), 456)
|
||||
require.NoError(t, secondErr, "failed touch should be debounced and skip immediate retry")
|
||||
require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again")
|
||||
}
|
||||
|
||||
type touchSingleflightRepo struct {
|
||||
|
||||
471
backend/internal/service/idempotency.go
Normal file
471
backend/internal/service/idempotency.go
Normal file
@@ -0,0 +1,471 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
)
|
||||
|
||||
const (
|
||||
IdempotencyStatusProcessing = "processing"
|
||||
IdempotencyStatusSucceeded = "succeeded"
|
||||
IdempotencyStatusFailedRetryable = "failed_retryable"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIdempotencyKeyRequired = infraerrors.BadRequest("IDEMPOTENCY_KEY_REQUIRED", "idempotency key is required")
|
||||
ErrIdempotencyKeyInvalid = infraerrors.BadRequest("IDEMPOTENCY_KEY_INVALID", "idempotency key is invalid")
|
||||
ErrIdempotencyKeyConflict = infraerrors.Conflict("IDEMPOTENCY_KEY_CONFLICT", "idempotency key reused with different payload")
|
||||
ErrIdempotencyInProgress = infraerrors.Conflict("IDEMPOTENCY_IN_PROGRESS", "idempotent request is still processing")
|
||||
ErrIdempotencyRetryBackoff = infraerrors.Conflict("IDEMPOTENCY_RETRY_BACKOFF", "idempotent request is in retry backoff window")
|
||||
ErrIdempotencyStoreUnavail = infraerrors.ServiceUnavailable("IDEMPOTENCY_STORE_UNAVAILABLE", "idempotency store unavailable")
|
||||
ErrIdempotencyInvalidPayload = infraerrors.BadRequest("IDEMPOTENCY_PAYLOAD_INVALID", "failed to normalize request payload")
|
||||
)
|
||||
|
||||
type IdempotencyRecord struct {
|
||||
ID int64
|
||||
Scope string
|
||||
IdempotencyKeyHash string
|
||||
RequestFingerprint string
|
||||
Status string
|
||||
ResponseStatus *int
|
||||
ResponseBody *string
|
||||
ErrorReason *string
|
||||
LockedUntil *time.Time
|
||||
ExpiresAt time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type IdempotencyRepository interface {
|
||||
CreateProcessing(ctx context.Context, record *IdempotencyRecord) (bool, error)
|
||||
GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*IdempotencyRecord, error)
|
||||
TryReclaim(ctx context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error)
|
||||
ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error)
|
||||
MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error
|
||||
MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error
|
||||
DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error)
|
||||
}
|
||||
|
||||
type IdempotencyConfig struct {
|
||||
DefaultTTL time.Duration
|
||||
SystemOperationTTL time.Duration
|
||||
ProcessingTimeout time.Duration
|
||||
FailedRetryBackoff time.Duration
|
||||
MaxStoredResponseLen int
|
||||
ObserveOnly bool
|
||||
}
|
||||
|
||||
func DefaultIdempotencyConfig() IdempotencyConfig {
|
||||
return IdempotencyConfig{
|
||||
DefaultTTL: 24 * time.Hour,
|
||||
SystemOperationTTL: 1 * time.Hour,
|
||||
ProcessingTimeout: 30 * time.Second,
|
||||
FailedRetryBackoff: 5 * time.Second,
|
||||
MaxStoredResponseLen: 64 * 1024,
|
||||
ObserveOnly: true, // 默认先观察再强制,避免老客户端立刻中断
|
||||
}
|
||||
}
|
||||
|
||||
type IdempotencyExecuteOptions struct {
|
||||
Scope string
|
||||
ActorScope string
|
||||
Method string
|
||||
Route string
|
||||
IdempotencyKey string
|
||||
Payload any
|
||||
TTL time.Duration
|
||||
RequireKey bool
|
||||
}
|
||||
|
||||
type IdempotencyExecuteResult struct {
|
||||
Data any
|
||||
Replayed bool
|
||||
}
|
||||
|
||||
type IdempotencyCoordinator struct {
|
||||
repo IdempotencyRepository
|
||||
cfg IdempotencyConfig
|
||||
}
|
||||
|
||||
var (
|
||||
defaultIdempotencyMu sync.RWMutex
|
||||
defaultIdempotencySvc *IdempotencyCoordinator
|
||||
)
|
||||
|
||||
func SetDefaultIdempotencyCoordinator(svc *IdempotencyCoordinator) {
|
||||
defaultIdempotencyMu.Lock()
|
||||
defaultIdempotencySvc = svc
|
||||
defaultIdempotencyMu.Unlock()
|
||||
}
|
||||
|
||||
func DefaultIdempotencyCoordinator() *IdempotencyCoordinator {
|
||||
defaultIdempotencyMu.RLock()
|
||||
defer defaultIdempotencyMu.RUnlock()
|
||||
return defaultIdempotencySvc
|
||||
}
|
||||
|
||||
func DefaultWriteIdempotencyTTL() time.Duration {
|
||||
defaultTTL := DefaultIdempotencyConfig().DefaultTTL
|
||||
if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.DefaultTTL > 0 {
|
||||
return coordinator.cfg.DefaultTTL
|
||||
}
|
||||
return defaultTTL
|
||||
}
|
||||
|
||||
func DefaultSystemOperationIdempotencyTTL() time.Duration {
|
||||
defaultTTL := DefaultIdempotencyConfig().SystemOperationTTL
|
||||
if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.SystemOperationTTL > 0 {
|
||||
return coordinator.cfg.SystemOperationTTL
|
||||
}
|
||||
return defaultTTL
|
||||
}
|
||||
|
||||
func NewIdempotencyCoordinator(repo IdempotencyRepository, cfg IdempotencyConfig) *IdempotencyCoordinator {
|
||||
return &IdempotencyCoordinator{
|
||||
repo: repo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func NormalizeIdempotencyKey(raw string) (string, error) {
|
||||
key := strings.TrimSpace(raw)
|
||||
if key == "" {
|
||||
return "", nil
|
||||
}
|
||||
if len(key) > 128 {
|
||||
return "", ErrIdempotencyKeyInvalid
|
||||
}
|
||||
for _, r := range key {
|
||||
if r < 33 || r > 126 {
|
||||
return "", ErrIdempotencyKeyInvalid
|
||||
}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func HashIdempotencyKey(key string) string {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func BuildIdempotencyFingerprint(method, route, actorScope string, payload any) (string, error) {
|
||||
if method == "" {
|
||||
method = "POST"
|
||||
}
|
||||
if route == "" {
|
||||
route = "/"
|
||||
}
|
||||
if actorScope == "" {
|
||||
actorScope = "anonymous"
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", ErrIdempotencyInvalidPayload.WithCause(err)
|
||||
}
|
||||
sum := sha256.Sum256([]byte(
|
||||
strings.ToUpper(method) + "\n" + route + "\n" + actorScope + "\n" + string(raw),
|
||||
))
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func RetryAfterSecondsFromError(err error) int {
|
||||
appErr := new(infraerrors.ApplicationError)
|
||||
if !errors.As(err, &appErr) || appErr == nil || appErr.Metadata == nil {
|
||||
return 0
|
||||
}
|
||||
v := strings.TrimSpace(appErr.Metadata["retry_after"])
|
||||
if v == "" {
|
||||
return 0
|
||||
}
|
||||
seconds, convErr := strconv.Atoi(v)
|
||||
if convErr != nil || seconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
return seconds
|
||||
}
|
||||
|
||||
func (c *IdempotencyCoordinator) Execute(
|
||||
ctx context.Context,
|
||||
opts IdempotencyExecuteOptions,
|
||||
execute func(context.Context) (any, error),
|
||||
) (*IdempotencyExecuteResult, error) {
|
||||
if execute == nil {
|
||||
return nil, infraerrors.InternalServer("IDEMPOTENCY_EXECUTOR_NIL", "idempotency executor is nil")
|
||||
}
|
||||
|
||||
key, err := NormalizeIdempotencyKey(opts.IdempotencyKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if key == "" {
|
||||
if opts.RequireKey && !c.cfg.ObserveOnly {
|
||||
return nil, ErrIdempotencyKeyRequired
|
||||
}
|
||||
data, execErr := execute(ctx)
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return &IdempotencyExecuteResult{Data: data}, nil
|
||||
}
|
||||
if c.repo == nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "repo_nil")
|
||||
return nil, ErrIdempotencyStoreUnavail
|
||||
}
|
||||
|
||||
if opts.Scope == "" {
|
||||
return nil, infraerrors.BadRequest("IDEMPOTENCY_SCOPE_REQUIRED", "idempotency scope is required")
|
||||
}
|
||||
|
||||
fingerprint, err := BuildIdempotencyFingerprint(opts.Method, opts.Route, opts.ActorScope, opts.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ttl := opts.TTL
|
||||
if ttl <= 0 {
|
||||
ttl = c.cfg.DefaultTTL
|
||||
}
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(ttl)
|
||||
lockedUntil := now.Add(c.cfg.ProcessingTimeout)
|
||||
keyHash := HashIdempotencyKey(key)
|
||||
|
||||
record := &IdempotencyRecord{
|
||||
Scope: opts.Scope,
|
||||
IdempotencyKeyHash: keyHash,
|
||||
RequestFingerprint: fingerprint,
|
||||
Status: IdempotencyStatusProcessing,
|
||||
LockedUntil: &lockedUntil,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
owner, err := c.repo.CreateProcessing(ctx, record)
|
||||
if err != nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "create_processing_error")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
|
||||
"operation": "create_processing",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(err)
|
||||
}
|
||||
if owner {
|
||||
recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "new_claim"})
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "none->processing", false, map[string]string{
|
||||
"claim_mode": "new",
|
||||
})
|
||||
}
|
||||
if !owner {
|
||||
existing, getErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash)
|
||||
if getErr != nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_error")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
|
||||
"operation": "get_existing",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(getErr)
|
||||
}
|
||||
if existing == nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
|
||||
"operation": "missing_existing",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail
|
||||
}
|
||||
if existing.RequestFingerprint != fingerprint {
|
||||
recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"})
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil)
|
||||
return nil, ErrIdempotencyKeyConflict
|
||||
}
|
||||
reclaimedByExpired := false
|
||||
if !existing.ExpiresAt.After(now) {
|
||||
taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, existing.Status, now, lockedUntil, expiresAt)
|
||||
if reclaimErr != nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_expired_error")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->store_unavailable", false, map[string]string{
|
||||
"operation": "try_reclaim_expired",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
|
||||
}
|
||||
if taken {
|
||||
reclaimedByExpired = true
|
||||
recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "expired_reclaim"})
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->processing", false, map[string]string{
|
||||
"claim_mode": "expired_reclaim",
|
||||
})
|
||||
record.ID = existing.ID
|
||||
} else {
|
||||
latest, latestErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash)
|
||||
if latestErr != nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_after_expired_reclaim_error")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
|
||||
"operation": "get_existing_after_expired_reclaim",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(latestErr)
|
||||
}
|
||||
if latest == nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing_after_expired_reclaim")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
|
||||
"operation": "missing_existing_after_expired_reclaim",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail
|
||||
}
|
||||
if latest.RequestFingerprint != fingerprint {
|
||||
recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"})
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil)
|
||||
return nil, ErrIdempotencyKeyConflict
|
||||
}
|
||||
existing = latest
|
||||
}
|
||||
}
|
||||
|
||||
if !reclaimedByExpired {
|
||||
switch existing.Status {
|
||||
case IdempotencyStatusSucceeded:
|
||||
data, parseErr := c.decodeStoredResponse(existing.ResponseBody)
|
||||
if parseErr != nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "decode_stored_response_error")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->store_unavailable", false, map[string]string{
|
||||
"operation": "decode_stored_response",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(parseErr)
|
||||
}
|
||||
recordIdempotencyReplay(opts.Route, opts.Scope, nil)
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->replayed", true, nil)
|
||||
return &IdempotencyExecuteResult{Data: data, Replayed: true}, nil
|
||||
case IdempotencyStatusProcessing:
|
||||
recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "in_progress"})
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->conflict", false, nil)
|
||||
return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now)
|
||||
case IdempotencyStatusFailedRetryable:
|
||||
if existing.LockedUntil != nil && existing.LockedUntil.After(now) {
|
||||
recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "retry_backoff"})
|
||||
recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil)
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->retry_backoff_conflict", false, nil)
|
||||
return nil, c.conflictWithRetryAfter(ErrIdempotencyRetryBackoff, existing.LockedUntil, now)
|
||||
}
|
||||
taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, IdempotencyStatusFailedRetryable, now, lockedUntil, expiresAt)
|
||||
if reclaimErr != nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_error")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->store_unavailable", false, map[string]string{
|
||||
"operation": "try_reclaim",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
|
||||
}
|
||||
if !taken {
|
||||
recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "reclaim_race"})
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->conflict", false, map[string]string{
|
||||
"conflict": "reclaim_race",
|
||||
})
|
||||
return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now)
|
||||
}
|
||||
recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "reclaim"})
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->processing", false, map[string]string{
|
||||
"claim_mode": "reclaim",
|
||||
})
|
||||
record.ID = existing.ID
|
||||
default:
|
||||
recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "unexpected_status"})
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->conflict", false, map[string]string{
|
||||
"status": existing.Status,
|
||||
})
|
||||
return nil, ErrIdempotencyKeyConflict
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if record.ID == 0 {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "record_id_missing")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
|
||||
"operation": "record_id_missing",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail
|
||||
}
|
||||
|
||||
execStart := time.Now()
|
||||
defer func() {
|
||||
recordIdempotencyProcessingDuration(opts.Route, opts.Scope, time.Since(execStart), nil)
|
||||
}()
|
||||
|
||||
data, execErr := execute(ctx)
|
||||
if execErr != nil {
|
||||
backoffUntil := time.Now().Add(c.cfg.FailedRetryBackoff)
|
||||
reason := infraerrors.Reason(execErr)
|
||||
if reason == "" {
|
||||
reason = "EXECUTION_FAILED"
|
||||
}
|
||||
recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil)
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->failed_retryable", false, map[string]string{
|
||||
"reason": reason,
|
||||
})
|
||||
if markErr := c.repo.MarkFailedRetryable(ctx, record.ID, reason, backoffUntil, expiresAt); markErr != nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_failed_retryable_error")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
|
||||
"operation": "mark_failed_retryable",
|
||||
})
|
||||
}
|
||||
return nil, execErr
|
||||
}
|
||||
|
||||
storedBody, marshalErr := c.marshalStoredResponse(data)
|
||||
if marshalErr != nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "marshal_response_error")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
|
||||
"operation": "marshal_response",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(marshalErr)
|
||||
}
|
||||
if markErr := c.repo.MarkSucceeded(ctx, record.ID, 200, storedBody, expiresAt); markErr != nil {
|
||||
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_succeeded_error")
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
|
||||
"operation": "mark_succeeded",
|
||||
})
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(markErr)
|
||||
}
|
||||
logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->succeeded", false, nil)
|
||||
|
||||
return &IdempotencyExecuteResult{Data: data}, nil
|
||||
}
|
||||
|
||||
func (c *IdempotencyCoordinator) conflictWithRetryAfter(base *infraerrors.ApplicationError, lockedUntil *time.Time, now time.Time) error {
|
||||
if lockedUntil == nil {
|
||||
return base
|
||||
}
|
||||
sec := int(lockedUntil.Sub(now).Seconds())
|
||||
if sec <= 0 {
|
||||
sec = 1
|
||||
}
|
||||
return base.WithMetadata(map[string]string{"retry_after": strconv.Itoa(sec)})
|
||||
}
|
||||
|
||||
func (c *IdempotencyCoordinator) marshalStoredResponse(data any) (string, error) {
|
||||
raw, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
redacted := logredact.RedactText(string(raw))
|
||||
if c.cfg.MaxStoredResponseLen > 0 && len(redacted) > c.cfg.MaxStoredResponseLen {
|
||||
redacted = redacted[:c.cfg.MaxStoredResponseLen] + "...(truncated)"
|
||||
}
|
||||
return redacted, nil
|
||||
}
|
||||
|
||||
func (c *IdempotencyCoordinator) decodeStoredResponse(stored *string) (any, error) {
|
||||
if stored == nil || strings.TrimSpace(*stored) == "" {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
var out any
|
||||
if err := json.Unmarshal([]byte(*stored), &out); err != nil {
|
||||
return nil, fmt.Errorf("decode stored response: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
91
backend/internal/service/idempotency_cleanup_service.go
Normal file
91
backend/internal/service/idempotency_cleanup_service.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// IdempotencyCleanupService 定期清理已过期的幂等记录,避免表无限增长。
|
||||
type IdempotencyCleanupService struct {
|
||||
repo IdempotencyRepository
|
||||
interval time.Duration
|
||||
batch int
|
||||
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func NewIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService {
|
||||
interval := 60 * time.Second
|
||||
batch := 500
|
||||
if cfg != nil {
|
||||
if cfg.Idempotency.CleanupIntervalSeconds > 0 {
|
||||
interval = time.Duration(cfg.Idempotency.CleanupIntervalSeconds) * time.Second
|
||||
}
|
||||
if cfg.Idempotency.CleanupBatchSize > 0 {
|
||||
batch = cfg.Idempotency.CleanupBatchSize
|
||||
}
|
||||
}
|
||||
return &IdempotencyCleanupService{
|
||||
repo: repo,
|
||||
interval: interval,
|
||||
batch: batch,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *IdempotencyCleanupService) Start() {
|
||||
if s == nil || s.repo == nil {
|
||||
return
|
||||
}
|
||||
s.startOnce.Do(func() {
|
||||
logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] started interval=%s batch=%d", s.interval, s.batch)
|
||||
go s.runLoop()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *IdempotencyCleanupService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] stopped")
|
||||
})
|
||||
}
|
||||
|
||||
func (s *IdempotencyCleanupService) runLoop() {
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动后先清理一轮,防止重启后积压。
|
||||
s.cleanupOnce()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.cleanupOnce()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *IdempotencyCleanupService) cleanupOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
deleted, err := s.repo.DeleteExpired(ctx, time.Now(), s.batch)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleanup failed err=%v", err)
|
||||
return
|
||||
}
|
||||
if deleted > 0 {
|
||||
logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleaned expired records count=%d", deleted)
|
||||
}
|
||||
}
|
||||
69
backend/internal/service/idempotency_cleanup_service_test.go
Normal file
69
backend/internal/service/idempotency_cleanup_service_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type idempotencyCleanupRepoStub struct {
|
||||
deleteCalls int
|
||||
lastLimit int
|
||||
deleteErr error
|
||||
}
|
||||
|
||||
func (r *idempotencyCleanupRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *idempotencyCleanupRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *idempotencyCleanupRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *idempotencyCleanupRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *idempotencyCleanupRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *idempotencyCleanupRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *idempotencyCleanupRepoStub) DeleteExpired(_ context.Context, _ time.Time, limit int) (int64, error) {
|
||||
r.deleteCalls++
|
||||
r.lastLimit = limit
|
||||
if r.deleteErr != nil {
|
||||
return 0, r.deleteErr
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func TestNewIdempotencyCleanupService_UsesConfig(t *testing.T) {
|
||||
repo := &idempotencyCleanupRepoStub{}
|
||||
cfg := &config.Config{
|
||||
Idempotency: config.IdempotencyConfig{
|
||||
CleanupIntervalSeconds: 7,
|
||||
CleanupBatchSize: 321,
|
||||
},
|
||||
}
|
||||
svc := NewIdempotencyCleanupService(repo, cfg)
|
||||
require.Equal(t, 7*time.Second, svc.interval)
|
||||
require.Equal(t, 321, svc.batch)
|
||||
}
|
||||
|
||||
func TestIdempotencyCleanupService_CleanupOnce(t *testing.T) {
|
||||
repo := &idempotencyCleanupRepoStub{}
|
||||
svc := NewIdempotencyCleanupService(repo, &config.Config{
|
||||
Idempotency: config.IdempotencyConfig{
|
||||
CleanupBatchSize: 99,
|
||||
},
|
||||
})
|
||||
|
||||
svc.cleanupOnce()
|
||||
require.Equal(t, 1, repo.deleteCalls)
|
||||
require.Equal(t, 99, repo.lastLimit)
|
||||
}
|
||||
171
backend/internal/service/idempotency_observability.go
Normal file
171
backend/internal/service/idempotency_observability.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// IdempotencyMetricsSnapshot 提供幂等核心指标快照(进程内累计)。
|
||||
type IdempotencyMetricsSnapshot struct {
|
||||
ClaimTotal uint64 `json:"claim_total"`
|
||||
ReplayTotal uint64 `json:"replay_total"`
|
||||
ConflictTotal uint64 `json:"conflict_total"`
|
||||
RetryBackoffTotal uint64 `json:"retry_backoff_total"`
|
||||
ProcessingDurationCount uint64 `json:"processing_duration_count"`
|
||||
ProcessingDurationTotalMs float64 `json:"processing_duration_total_ms"`
|
||||
StoreUnavailableTotal uint64 `json:"store_unavailable_total"`
|
||||
}
|
||||
|
||||
type idempotencyMetrics struct {
|
||||
claimTotal atomic.Uint64
|
||||
replayTotal atomic.Uint64
|
||||
conflictTotal atomic.Uint64
|
||||
retryBackoffTotal atomic.Uint64
|
||||
processingDurationCount atomic.Uint64
|
||||
processingDurationMicros atomic.Uint64
|
||||
storeUnavailableTotal atomic.Uint64
|
||||
}
|
||||
|
||||
var defaultIdempotencyMetrics idempotencyMetrics
|
||||
|
||||
// GetIdempotencyMetricsSnapshot 返回当前幂等指标快照。
|
||||
func GetIdempotencyMetricsSnapshot() IdempotencyMetricsSnapshot {
|
||||
totalMicros := defaultIdempotencyMetrics.processingDurationMicros.Load()
|
||||
return IdempotencyMetricsSnapshot{
|
||||
ClaimTotal: defaultIdempotencyMetrics.claimTotal.Load(),
|
||||
ReplayTotal: defaultIdempotencyMetrics.replayTotal.Load(),
|
||||
ConflictTotal: defaultIdempotencyMetrics.conflictTotal.Load(),
|
||||
RetryBackoffTotal: defaultIdempotencyMetrics.retryBackoffTotal.Load(),
|
||||
ProcessingDurationCount: defaultIdempotencyMetrics.processingDurationCount.Load(),
|
||||
ProcessingDurationTotalMs: float64(totalMicros) / 1000.0,
|
||||
StoreUnavailableTotal: defaultIdempotencyMetrics.storeUnavailableTotal.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func recordIdempotencyClaim(endpoint, scope string, attrs map[string]string) {
|
||||
defaultIdempotencyMetrics.claimTotal.Add(1)
|
||||
logIdempotencyMetric("idempotency_claim_total", endpoint, scope, "1", attrs)
|
||||
}
|
||||
|
||||
func recordIdempotencyReplay(endpoint, scope string, attrs map[string]string) {
|
||||
defaultIdempotencyMetrics.replayTotal.Add(1)
|
||||
logIdempotencyMetric("idempotency_replay_total", endpoint, scope, "1", attrs)
|
||||
}
|
||||
|
||||
func recordIdempotencyConflict(endpoint, scope string, attrs map[string]string) {
|
||||
defaultIdempotencyMetrics.conflictTotal.Add(1)
|
||||
logIdempotencyMetric("idempotency_conflict_total", endpoint, scope, "1", attrs)
|
||||
}
|
||||
|
||||
func recordIdempotencyRetryBackoff(endpoint, scope string, attrs map[string]string) {
|
||||
defaultIdempotencyMetrics.retryBackoffTotal.Add(1)
|
||||
logIdempotencyMetric("idempotency_retry_backoff_total", endpoint, scope, "1", attrs)
|
||||
}
|
||||
|
||||
func recordIdempotencyProcessingDuration(endpoint, scope string, duration time.Duration, attrs map[string]string) {
|
||||
if duration < 0 {
|
||||
duration = 0
|
||||
}
|
||||
defaultIdempotencyMetrics.processingDurationCount.Add(1)
|
||||
defaultIdempotencyMetrics.processingDurationMicros.Add(uint64(duration.Microseconds()))
|
||||
logIdempotencyMetric("idempotency_processing_duration_ms", endpoint, scope, strconv.FormatFloat(duration.Seconds()*1000, 'f', 3, 64), attrs)
|
||||
}
|
||||
|
||||
// RecordIdempotencyStoreUnavailable 记录幂等存储不可用事件(用于降级路径观测)。
|
||||
func RecordIdempotencyStoreUnavailable(endpoint, scope, strategy string) {
|
||||
defaultIdempotencyMetrics.storeUnavailableTotal.Add(1)
|
||||
attrs := map[string]string{}
|
||||
if strategy != "" {
|
||||
attrs["strategy"] = strategy
|
||||
}
|
||||
logIdempotencyMetric("idempotency_store_unavailable_total", endpoint, scope, "1", attrs)
|
||||
}
|
||||
|
||||
func logIdempotencyAudit(endpoint, scope, keyHash, stateTransition string, replayed bool, attrs map[string]string) {
|
||||
var b strings.Builder
|
||||
builderWriteString(&b, "[IdempotencyAudit]")
|
||||
builderWriteString(&b, " endpoint=")
|
||||
builderWriteString(&b, safeAuditField(endpoint))
|
||||
builderWriteString(&b, " scope=")
|
||||
builderWriteString(&b, safeAuditField(scope))
|
||||
builderWriteString(&b, " key_hash=")
|
||||
builderWriteString(&b, safeAuditField(keyHash))
|
||||
builderWriteString(&b, " state_transition=")
|
||||
builderWriteString(&b, safeAuditField(stateTransition))
|
||||
builderWriteString(&b, " replayed=")
|
||||
builderWriteString(&b, strconv.FormatBool(replayed))
|
||||
if len(attrs) > 0 {
|
||||
appendSortedAttrs(&b, attrs)
|
||||
}
|
||||
logger.LegacyPrintf("service.idempotency", "%s", b.String())
|
||||
}
|
||||
|
||||
func logIdempotencyMetric(name, endpoint, scope, value string, attrs map[string]string) {
|
||||
var b strings.Builder
|
||||
builderWriteString(&b, "[IdempotencyMetric]")
|
||||
builderWriteString(&b, " name=")
|
||||
builderWriteString(&b, safeAuditField(name))
|
||||
builderWriteString(&b, " endpoint=")
|
||||
builderWriteString(&b, safeAuditField(endpoint))
|
||||
builderWriteString(&b, " scope=")
|
||||
builderWriteString(&b, safeAuditField(scope))
|
||||
builderWriteString(&b, " value=")
|
||||
builderWriteString(&b, safeAuditField(value))
|
||||
if len(attrs) > 0 {
|
||||
appendSortedAttrs(&b, attrs)
|
||||
}
|
||||
logger.LegacyPrintf("service.idempotency", "%s", b.String())
|
||||
}
|
||||
|
||||
func appendSortedAttrs(builder *strings.Builder, attrs map[string]string) {
|
||||
if len(attrs) == 0 {
|
||||
return
|
||||
}
|
||||
keys := make([]string, 0, len(attrs))
|
||||
for k := range attrs {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
builderWriteByte(builder, ' ')
|
||||
builderWriteString(builder, k)
|
||||
builderWriteByte(builder, '=')
|
||||
builderWriteString(builder, safeAuditField(attrs[k]))
|
||||
}
|
||||
}
|
||||
|
||||
func safeAuditField(v string) string {
|
||||
value := strings.TrimSpace(v)
|
||||
if value == "" {
|
||||
return "-"
|
||||
}
|
||||
// 日志按 key=value 输出,替换空白避免解析歧义。
|
||||
value = strings.ReplaceAll(value, "\n", "_")
|
||||
value = strings.ReplaceAll(value, "\r", "_")
|
||||
value = strings.ReplaceAll(value, "\t", "_")
|
||||
value = strings.ReplaceAll(value, " ", "_")
|
||||
return value
|
||||
}
|
||||
|
||||
func resetIdempotencyMetricsForTest() {
|
||||
defaultIdempotencyMetrics.claimTotal.Store(0)
|
||||
defaultIdempotencyMetrics.replayTotal.Store(0)
|
||||
defaultIdempotencyMetrics.conflictTotal.Store(0)
|
||||
defaultIdempotencyMetrics.retryBackoffTotal.Store(0)
|
||||
defaultIdempotencyMetrics.processingDurationCount.Store(0)
|
||||
defaultIdempotencyMetrics.processingDurationMicros.Store(0)
|
||||
defaultIdempotencyMetrics.storeUnavailableTotal.Store(0)
|
||||
}
|
||||
|
||||
func builderWriteString(builder *strings.Builder, value string) {
|
||||
_, _ = builder.WriteString(value)
|
||||
}
|
||||
|
||||
func builderWriteByte(builder *strings.Builder, value byte) {
|
||||
_ = builder.WriteByte(value)
|
||||
}
|
||||
805
backend/internal/service/idempotency_test.go
Normal file
805
backend/internal/service/idempotency_test.go
Normal file
@@ -0,0 +1,805 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type inMemoryIdempotencyRepo struct {
|
||||
mu sync.Mutex
|
||||
nextID int64
|
||||
data map[string]*IdempotencyRecord
|
||||
}
|
||||
|
||||
func newInMemoryIdempotencyRepo() *inMemoryIdempotencyRepo {
|
||||
return &inMemoryIdempotencyRepo{
|
||||
nextID: 1,
|
||||
data: make(map[string]*IdempotencyRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) key(scope, hash string) string {
|
||||
return scope + "|" + hash
|
||||
}
|
||||
|
||||
func cloneRecord(in *IdempotencyRecord) *IdempotencyRecord {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := *in
|
||||
if in.ResponseStatus != nil {
|
||||
v := *in.ResponseStatus
|
||||
out.ResponseStatus = &v
|
||||
}
|
||||
if in.ResponseBody != nil {
|
||||
v := *in.ResponseBody
|
||||
out.ResponseBody = &v
|
||||
}
|
||||
if in.ErrorReason != nil {
|
||||
v := *in.ErrorReason
|
||||
out.ErrorReason = &v
|
||||
}
|
||||
if in.LockedUntil != nil {
|
||||
v := *in.LockedUntil
|
||||
out.LockedUntil = &v
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) CreateProcessing(_ context.Context, record *IdempotencyRecord) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
k := r.key(record.Scope, record.IdempotencyKeyHash)
|
||||
if _, ok := r.data[k]; ok {
|
||||
return false, nil
|
||||
}
|
||||
rec := cloneRecord(record)
|
||||
rec.ID = r.nextID
|
||||
rec.CreatedAt = time.Now()
|
||||
rec.UpdatedAt = rec.CreatedAt
|
||||
r.nextID++
|
||||
r.data[k] = rec
|
||||
record.ID = rec.ID
|
||||
record.CreatedAt = rec.CreatedAt
|
||||
record.UpdatedAt = rec.UpdatedAt
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*IdempotencyRecord, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return cloneRecord(r.data[r.key(scope, keyHash)]), nil
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != fromStatus {
|
||||
return false, nil
|
||||
}
|
||||
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
|
||||
return false, nil
|
||||
}
|
||||
rec.Status = IdempotencyStatusProcessing
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.ErrorReason = nil
|
||||
rec.UpdatedAt = time.Now()
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
|
||||
return false, nil
|
||||
}
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.UpdatedAt = time.Now()
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = IdempotencyStatusSucceeded
|
||||
rec.LockedUntil = nil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.UpdatedAt = time.Now()
|
||||
rec.ErrorReason = nil
|
||||
rec.ResponseStatus = &responseStatus
|
||||
rec.ResponseBody = &responseBody
|
||||
return nil
|
||||
}
|
||||
return errors.New("record not found")
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = IdempotencyStatusFailedRetryable
|
||||
rec.LockedUntil = &lockedUntil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.UpdatedAt = time.Now()
|
||||
rec.ErrorReason = &errorReason
|
||||
return nil
|
||||
}
|
||||
return errors.New("record not found")
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) DeleteExpired(_ context.Context, now time.Time, _ int) (int64, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
var deleted int64
|
||||
for k, rec := range r.data {
|
||||
if !rec.ExpiresAt.After(now) {
|
||||
delete(r.data, k)
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_RequireKey(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
cfg.ObserveOnly = false
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "admin:1",
|
||||
RequireKey: true,
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyRequired))
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_ReplaySucceededResult(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
execCount := 0
|
||||
exec := func(ctx context.Context) (any, error) {
|
||||
execCount++
|
||||
return map[string]any{"count": execCount}, nil
|
||||
}
|
||||
|
||||
opts := IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "user:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-1",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}
|
||||
|
||||
first, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.False(t, first.Replayed)
|
||||
|
||||
second, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.True(t, second.Replayed)
|
||||
require.Equal(t, 1, execCount, "second request should replay without executing business logic")
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.Equal(t, uint64(1), metrics.ClaimTotal)
|
||||
require.Equal(t, uint64(1), metrics.ReplayTotal)
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
|
||||
|
||||
opts := IdempotencyExecuteOptions{
|
||||
Scope: "test.scope.expired",
|
||||
Method: "POST",
|
||||
Route: "/test/expired",
|
||||
ActorScope: "user:99",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "expired-case",
|
||||
Payload: map[string]any{"k": "v"},
|
||||
}
|
||||
|
||||
execCount := 0
|
||||
exec := func(ctx context.Context) (any, error) {
|
||||
execCount++
|
||||
return map[string]any{"count": execCount}, nil
|
||||
}
|
||||
|
||||
first, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, first)
|
||||
require.False(t, first.Replayed)
|
||||
require.Equal(t, 1, execCount)
|
||||
|
||||
keyHash := HashIdempotencyKey(opts.IdempotencyKey)
|
||||
repo.mu.Lock()
|
||||
existing := repo.data[repo.key(opts.Scope, keyHash)]
|
||||
require.NotNil(t, existing)
|
||||
existing.ExpiresAt = time.Now().Add(-time.Second)
|
||||
repo.mu.Unlock()
|
||||
|
||||
second, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, second)
|
||||
require.False(t, second.Replayed, "expired record should be reclaimed and execute business logic again")
|
||||
require.Equal(t, 2, execCount)
|
||||
|
||||
third, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, third)
|
||||
require.True(t, third.Replayed)
|
||||
payload, ok := third.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, float64(2), payload["count"])
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.GreaterOrEqual(t, metrics.ClaimTotal, uint64(2))
|
||||
require.GreaterOrEqual(t, metrics.ReplayTotal, uint64(1))
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "user:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-2",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "user:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-2",
|
||||
Payload: map[string]any{"a": 2},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyConflict))
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.Equal(t, uint64(1), metrics.ConflictTotal)
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_BackoffAfterRetryableFailure(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
cfg.FailedRetryBackoff = 2 * time.Second
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
opts := IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "user:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-3",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
|
||||
return nil, infraerrors.InternalServer("UPSTREAM_ERROR", "upstream error")
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyRetryBackoff))
|
||||
require.Greater(t, RetryAfterSecondsFromError(err), 0)
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.GreaterOrEqual(t, metrics.RetryBackoffTotal, uint64(2))
|
||||
require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1))
|
||||
require.GreaterOrEqual(t, metrics.ProcessingDurationCount, uint64(1))
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
cfg.ProcessingTimeout = 2 * time.Second
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
opts := IdempotencyExecuteOptions{
|
||||
Scope: "test.scope.concurrent",
|
||||
Method: "POST",
|
||||
Route: "/test/concurrent",
|
||||
ActorScope: "user:7",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "concurrent-case",
|
||||
Payload: map[string]any{"v": 1},
|
||||
}
|
||||
|
||||
var execCount int32
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
|
||||
atomic.AddInt32(&execCount, 1)
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
replayed, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
|
||||
atomic.AddInt32(&execCount, 1)
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, replayed.Replayed)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&execCount), "concurrent same-key requests should execute business side-effect once")
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.Equal(t, uint64(1), metrics.ClaimTotal)
|
||||
require.Equal(t, uint64(1), metrics.ReplayTotal)
|
||||
require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1))
|
||||
}
|
||||
|
||||
type failingIdempotencyRepo struct{}
|
||||
|
||||
func (failingIdempotencyRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
|
||||
return nil, errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, errors.New("store unavailable")
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_StoreUnavailableMetrics(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
coordinator := NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig())
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "test.scope.unavailable",
|
||||
Method: "POST",
|
||||
Route: "/test/unavailable",
|
||||
ActorScope: "admin:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-unavailable",
|
||||
Payload: map[string]any{"v": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
require.GreaterOrEqual(t, GetIdempotencyMetricsSnapshot().StoreUnavailableTotal, uint64(1))
|
||||
}
|
||||
|
||||
func TestDefaultIdempotencyCoordinatorAndTTLs(t *testing.T) {
|
||||
SetDefaultIdempotencyCoordinator(nil)
|
||||
require.Nil(t, DefaultIdempotencyCoordinator())
|
||||
require.Equal(t, DefaultIdempotencyConfig().DefaultTTL, DefaultWriteIdempotencyTTL())
|
||||
require.Equal(t, DefaultIdempotencyConfig().SystemOperationTTL, DefaultSystemOperationIdempotencyTTL())
|
||||
|
||||
coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{
|
||||
DefaultTTL: 2 * time.Hour,
|
||||
SystemOperationTTL: 15 * time.Minute,
|
||||
ProcessingTimeout: 10 * time.Second,
|
||||
FailedRetryBackoff: 3 * time.Second,
|
||||
ObserveOnly: false,
|
||||
})
|
||||
SetDefaultIdempotencyCoordinator(coordinator)
|
||||
t.Cleanup(func() {
|
||||
SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
require.Same(t, coordinator, DefaultIdempotencyCoordinator())
|
||||
require.Equal(t, 2*time.Hour, DefaultWriteIdempotencyTTL())
|
||||
require.Equal(t, 15*time.Minute, DefaultSystemOperationIdempotencyTTL())
|
||||
}
|
||||
|
||||
func TestNormalizeIdempotencyKeyAndFingerprint(t *testing.T) {
|
||||
key, err := NormalizeIdempotencyKey(" abc-123 ")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "abc-123", key)
|
||||
|
||||
key, err = NormalizeIdempotencyKey("")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", key)
|
||||
|
||||
_, err = NormalizeIdempotencyKey(string(make([]byte, 129)))
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = NormalizeIdempotencyKey("bad\nkey")
|
||||
require.Error(t, err)
|
||||
|
||||
fp1, err := BuildIdempotencyFingerprint("", "", "", map[string]any{"a": 1})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, fp1)
|
||||
fp2, err := BuildIdempotencyFingerprint("POST", "/", "anonymous", map[string]any{"a": 1})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fp1, fp2)
|
||||
|
||||
_, err = BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"bad": make(chan int)})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyInvalidPayload), infraerrors.Code(err))
|
||||
}
|
||||
|
||||
func TestRetryAfterSecondsFromErrorBranches(t *testing.T) {
|
||||
require.Equal(t, 0, RetryAfterSecondsFromError(nil))
|
||||
require.Equal(t, 0, RetryAfterSecondsFromError(errors.New("plain")))
|
||||
|
||||
err := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "12"})
|
||||
require.Equal(t, 12, RetryAfterSecondsFromError(err))
|
||||
|
||||
err = ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "bad"})
|
||||
require.Equal(t, 0, RetryAfterSecondsFromError(err))
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_ExecuteNilExecutorAndNoKeyPassThrough(t *testing.T) {
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope",
|
||||
IdempotencyKey: "k",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, nil)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "IDEMPOTENCY_EXECUTOR_NIL", infraerrors.Reason(err))
|
||||
|
||||
called := 0
|
||||
result, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope",
|
||||
RequireKey: true,
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
called++
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, called)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.Replayed)
|
||||
}
|
||||
|
||||
type noIDOwnerRepo struct{}
|
||||
|
||||
func (noIDOwnerRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (noIDOwnerRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (noIDOwnerRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (noIDOwnerRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (noIDOwnerRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { return nil }
|
||||
func (noIDOwnerRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (noIDOwnerRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { return 0, nil }
|
||||
|
||||
func TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing(t *testing.T) {
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
coordinator := NewIdempotencyCoordinator(nil, cfg)
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope",
|
||||
IdempotencyKey: "k",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
|
||||
coordinator = NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg)
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
IdempotencyKey: "k2",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "IDEMPOTENCY_SCOPE_REQUIRED", infraerrors.Reason(err))
|
||||
|
||||
coordinator = NewIdempotencyCoordinator(noIDOwnerRepo{}, cfg)
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope-no-id",
|
||||
IdempotencyKey: "k3",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
}
|
||||
|
||||
type conflictBranchRepo struct {
|
||||
existing *IdempotencyRecord
|
||||
tryReclaimErr error
|
||||
tryReclaimOK bool
|
||||
}
|
||||
|
||||
func (r *conflictBranchRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *conflictBranchRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
|
||||
return cloneRecord(r.existing), nil
|
||||
}
|
||||
func (r *conflictBranchRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
if r.tryReclaimErr != nil {
|
||||
return false, r.tryReclaimErr
|
||||
}
|
||||
return r.tryReclaimOK, nil
|
||||
}
|
||||
func (r *conflictBranchRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *conflictBranchRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *conflictBranchRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *conflictBranchRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_ConflictBranchesAndDecodeError(t *testing.T) {
|
||||
now := time.Now()
|
||||
fp, err := BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"a": 1})
|
||||
require.NoError(t, err)
|
||||
badBody := "{bad-json"
|
||||
repo := &conflictBranchRepo{
|
||||
existing: &IdempotencyRecord{
|
||||
ID: 1,
|
||||
Scope: "scope",
|
||||
IdempotencyKeyHash: HashIdempotencyKey("k"),
|
||||
RequestFingerprint: fp,
|
||||
Status: IdempotencyStatusSucceeded,
|
||||
ResponseBody: &badBody,
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
},
|
||||
}
|
||||
coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope",
|
||||
IdempotencyKey: "k",
|
||||
Method: "POST",
|
||||
Route: "/x",
|
||||
ActorScope: "u:1",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
|
||||
repo.existing = &IdempotencyRecord{
|
||||
ID: 2,
|
||||
Scope: "scope",
|
||||
IdempotencyKeyHash: HashIdempotencyKey("k"),
|
||||
RequestFingerprint: fp,
|
||||
Status: "unknown",
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
}
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope",
|
||||
IdempotencyKey: "k",
|
||||
Method: "POST",
|
||||
Route: "/x",
|
||||
ActorScope: "u:1",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyKeyConflict), infraerrors.Code(err))
|
||||
|
||||
repo.existing = &IdempotencyRecord{
|
||||
ID: 3,
|
||||
Scope: "scope",
|
||||
IdempotencyKeyHash: HashIdempotencyKey("k"),
|
||||
RequestFingerprint: fp,
|
||||
Status: IdempotencyStatusFailedRetryable,
|
||||
LockedUntil: ptrTime(now.Add(-time.Second)),
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
}
|
||||
repo.tryReclaimErr = errors.New("reclaim down")
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope",
|
||||
IdempotencyKey: "k",
|
||||
Method: "POST",
|
||||
Route: "/x",
|
||||
ActorScope: "u:1",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
|
||||
repo.tryReclaimErr = nil
|
||||
repo.tryReclaimOK = false
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope",
|
||||
IdempotencyKey: "k",
|
||||
Method: "POST",
|
||||
Route: "/x",
|
||||
ActorScope: "u:1",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyInProgress), infraerrors.Code(err))
|
||||
}
|
||||
|
||||
type markBehaviorRepo struct {
|
||||
inMemoryIdempotencyRepo
|
||||
failMarkSucceeded bool
|
||||
failMarkFailed bool
|
||||
}
|
||||
|
||||
func (r *markBehaviorRepo) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
if r.failMarkSucceeded {
|
||||
return errors.New("mark succeeded failed")
|
||||
}
|
||||
return r.inMemoryIdempotencyRepo.MarkSucceeded(ctx, id, responseStatus, responseBody, expiresAt)
|
||||
}
|
||||
|
||||
func (r *markBehaviorRepo) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
if r.failMarkFailed {
|
||||
return errors.New("mark failed retryable failed")
|
||||
}
|
||||
return r.inMemoryIdempotencyRepo.MarkFailedRetryable(ctx, id, errorReason, lockedUntil, expiresAt)
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_MarkAndMarshalBranches(t *testing.T) {
|
||||
repo := &markBehaviorRepo{inMemoryIdempotencyRepo: *newInMemoryIdempotencyRepo()}
|
||||
coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
|
||||
|
||||
repo.failMarkSucceeded = true
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope-success",
|
||||
IdempotencyKey: "k1",
|
||||
Method: "POST",
|
||||
Route: "/ok",
|
||||
ActorScope: "u:1",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
|
||||
repo.failMarkSucceeded = false
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope-marshal",
|
||||
IdempotencyKey: "k2",
|
||||
Method: "POST",
|
||||
Route: "/bad",
|
||||
ActorScope: "u:1",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"bad": make(chan int)}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
|
||||
repo.failMarkFailed = true
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope-fail",
|
||||
IdempotencyKey: "k3",
|
||||
Method: "POST",
|
||||
Route: "/fail",
|
||||
ActorScope: "u:1",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return nil, errors.New("plain failure")
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "plain failure", err.Error())
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_HelperBranches(t *testing.T) {
|
||||
c := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{
|
||||
DefaultTTL: time.Hour,
|
||||
SystemOperationTTL: time.Hour,
|
||||
ProcessingTimeout: time.Second,
|
||||
FailedRetryBackoff: time.Second,
|
||||
MaxStoredResponseLen: 12,
|
||||
ObserveOnly: false,
|
||||
})
|
||||
|
||||
// conflictWithRetryAfter without locked_until should return base error.
|
||||
base := ErrIdempotencyInProgress
|
||||
err := c.conflictWithRetryAfter(base, nil, time.Now())
|
||||
require.Equal(t, infraerrors.Code(base), infraerrors.Code(err))
|
||||
|
||||
// marshalStoredResponse should truncate.
|
||||
body, err := c.marshalStoredResponse(map[string]any{"long": "abcdefghijklmnopqrstuvwxyz"})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, body, "...(truncated)")
|
||||
|
||||
// decodeStoredResponse empty and invalid json.
|
||||
out, err := c.decodeStoredResponse(nil)
|
||||
require.NoError(t, err)
|
||||
_, ok := out.(map[string]any)
|
||||
require.True(t, ok)
|
||||
|
||||
invalid := "{invalid"
|
||||
_, err = c.decodeStoredResponse(&invalid)
|
||||
require.Error(t, err)
|
||||
}
|
||||
389
backend/internal/service/subscription_assign_idempotency_test.go
Normal file
389
backend/internal/service/subscription_assign_idempotency_test.go
Normal file
@@ -0,0 +1,389 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type groupRepoNoop struct{}
|
||||
|
||||
func (groupRepoNoop) Create(context.Context, *Group) error { panic("unexpected Create call") }
|
||||
func (groupRepoNoop) GetByID(context.Context, int64) (*Group, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
func (groupRepoNoop) GetByIDLite(context.Context, int64) (*Group, error) {
|
||||
panic("unexpected GetByIDLite call")
|
||||
}
|
||||
func (groupRepoNoop) Update(context.Context, *Group) error { panic("unexpected Update call") }
|
||||
func (groupRepoNoop) Delete(context.Context, int64) error { panic("unexpected Delete call") }
|
||||
func (groupRepoNoop) DeleteCascade(context.Context, int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
func (groupRepoNoop) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
func (groupRepoNoop) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
func (groupRepoNoop) ListActive(context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
func (groupRepoNoop) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
func (groupRepoNoop) BindAccountsToGroup(context.Context, int64, []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
func (groupRepoNoop) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
|
||||
panic("unexpected UpdateSortOrders call")
|
||||
}
|
||||
|
||||
type subscriptionGroupRepoStub struct {
|
||||
groupRepoNoop
|
||||
group *Group
|
||||
}
|
||||
|
||||
func (s *subscriptionGroupRepoStub) GetByID(context.Context, int64) (*Group, error) {
|
||||
return s.group, nil
|
||||
}
|
||||
|
||||
type userSubRepoNoop struct{}
|
||||
|
||||
func (userSubRepoNoop) Create(context.Context, *UserSubscription) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
func (userSubRepoNoop) GetByID(context.Context, int64) (*UserSubscription, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
func (userSubRepoNoop) GetByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
|
||||
panic("unexpected GetByUserIDAndGroupID call")
|
||||
}
|
||||
func (userSubRepoNoop) GetActiveByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
|
||||
panic("unexpected GetActiveByUserIDAndGroupID call")
|
||||
}
|
||||
func (userSubRepoNoop) Update(context.Context, *UserSubscription) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
func (userSubRepoNoop) Delete(context.Context, int64) error { panic("unexpected Delete call") }
|
||||
func (userSubRepoNoop) ListByUserID(context.Context, int64) ([]UserSubscription, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscription, error) {
|
||||
panic("unexpected ListActiveByUserID call")
|
||||
}
|
||||
func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {
|
||||
panic("unexpected ExistsByUserIDAndGroupID call")
|
||||
}
|
||||
func (userSubRepoNoop) ExtendExpiry(context.Context, int64, time.Time) error {
|
||||
panic("unexpected ExtendExpiry call")
|
||||
}
|
||||
func (userSubRepoNoop) UpdateStatus(context.Context, int64, string) error {
|
||||
panic("unexpected UpdateStatus call")
|
||||
}
|
||||
func (userSubRepoNoop) UpdateNotes(context.Context, int64, string) error {
|
||||
panic("unexpected UpdateNotes call")
|
||||
}
|
||||
func (userSubRepoNoop) ActivateWindows(context.Context, int64, time.Time) error {
|
||||
panic("unexpected ActivateWindows call")
|
||||
}
|
||||
func (userSubRepoNoop) ResetDailyUsage(context.Context, int64, time.Time) error {
|
||||
panic("unexpected ResetDailyUsage call")
|
||||
}
|
||||
func (userSubRepoNoop) ResetWeeklyUsage(context.Context, int64, time.Time) error {
|
||||
panic("unexpected ResetWeeklyUsage call")
|
||||
}
|
||||
func (userSubRepoNoop) ResetMonthlyUsage(context.Context, int64, time.Time) error {
|
||||
panic("unexpected ResetMonthlyUsage call")
|
||||
}
|
||||
func (userSubRepoNoop) IncrementUsage(context.Context, int64, float64) error {
|
||||
panic("unexpected IncrementUsage call")
|
||||
}
|
||||
func (userSubRepoNoop) BatchUpdateExpiredStatus(context.Context) (int64, error) {
|
||||
panic("unexpected BatchUpdateExpiredStatus call")
|
||||
}
|
||||
|
||||
type subscriptionUserSubRepoStub struct {
|
||||
userSubRepoNoop
|
||||
|
||||
nextID int64
|
||||
byID map[int64]*UserSubscription
|
||||
byUserGroup map[string]*UserSubscription
|
||||
createCalls int
|
||||
}
|
||||
|
||||
func newSubscriptionUserSubRepoStub() *subscriptionUserSubRepoStub {
|
||||
return &subscriptionUserSubRepoStub{
|
||||
nextID: 1,
|
||||
byID: make(map[int64]*UserSubscription),
|
||||
byUserGroup: make(map[string]*UserSubscription),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *subscriptionUserSubRepoStub) key(userID, groupID int64) string {
|
||||
return strconvFormatInt(userID) + ":" + strconvFormatInt(groupID)
|
||||
}
|
||||
|
||||
func (s *subscriptionUserSubRepoStub) seed(sub *UserSubscription) {
|
||||
if sub == nil {
|
||||
return
|
||||
}
|
||||
cp := *sub
|
||||
if cp.ID == 0 {
|
||||
cp.ID = s.nextID
|
||||
s.nextID++
|
||||
}
|
||||
s.byID[cp.ID] = &cp
|
||||
s.byUserGroup[s.key(cp.UserID, cp.GroupID)] = &cp
|
||||
}
|
||||
|
||||
func (s *subscriptionUserSubRepoStub) ExistsByUserIDAndGroupID(_ context.Context, userID, groupID int64) (bool, error) {
|
||||
_, ok := s.byUserGroup[s.key(userID, groupID)]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (s *subscriptionUserSubRepoStub) GetByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
|
||||
sub := s.byUserGroup[s.key(userID, groupID)]
|
||||
if sub == nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
cp := *sub
|
||||
return &cp, nil
|
||||
}
|
||||
|
||||
func (s *subscriptionUserSubRepoStub) Create(_ context.Context, sub *UserSubscription) error {
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
s.createCalls++
|
||||
cp := *sub
|
||||
if cp.ID == 0 {
|
||||
cp.ID = s.nextID
|
||||
s.nextID++
|
||||
}
|
||||
sub.ID = cp.ID
|
||||
s.byID[cp.ID] = &cp
|
||||
s.byUserGroup[s.key(cp.UserID, cp.GroupID)] = &cp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *subscriptionUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) {
|
||||
sub := s.byID[id]
|
||||
if sub == nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
cp := *sub
|
||||
return &cp, nil
|
||||
}
|
||||
|
||||
func TestAssignSubscriptionReuseWhenSemanticsMatch(t *testing.T) {
|
||||
start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
|
||||
groupRepo := &subscriptionGroupRepoStub{
|
||||
group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
|
||||
}
|
||||
subRepo := newSubscriptionUserSubRepoStub()
|
||||
subRepo.seed(&UserSubscription{
|
||||
ID: 10,
|
||||
UserID: 1001,
|
||||
GroupID: 1,
|
||||
StartsAt: start,
|
||||
ExpiresAt: start.AddDate(0, 0, 30),
|
||||
Notes: "init",
|
||||
})
|
||||
|
||||
svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
|
||||
sub, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
|
||||
UserID: 1001,
|
||||
GroupID: 1,
|
||||
ValidityDays: 30,
|
||||
Notes: "init",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), sub.ID)
|
||||
require.Equal(t, 0, subRepo.createCalls, "reuse should not create new subscription")
|
||||
}
|
||||
|
||||
func TestAssignSubscriptionConflictWhenSemanticsMismatch(t *testing.T) {
|
||||
start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
|
||||
groupRepo := &subscriptionGroupRepoStub{
|
||||
group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
|
||||
}
|
||||
subRepo := newSubscriptionUserSubRepoStub()
|
||||
subRepo.seed(&UserSubscription{
|
||||
ID: 11,
|
||||
UserID: 2001,
|
||||
GroupID: 1,
|
||||
StartsAt: start,
|
||||
ExpiresAt: start.AddDate(0, 0, 30),
|
||||
Notes: "old-note",
|
||||
})
|
||||
|
||||
svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
|
||||
_, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
|
||||
UserID: 2001,
|
||||
GroupID: 1,
|
||||
ValidityDays: 30,
|
||||
Notes: "new-note",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "SUBSCRIPTION_ASSIGN_CONFLICT", infraerrorsReason(err))
|
||||
require.Equal(t, 0, subRepo.createCalls, "conflict should not create or mutate existing subscription")
|
||||
}
|
||||
|
||||
func TestBulkAssignSubscriptionCreatedReusedAndConflict(t *testing.T) {
|
||||
start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
|
||||
groupRepo := &subscriptionGroupRepoStub{
|
||||
group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
|
||||
}
|
||||
subRepo := newSubscriptionUserSubRepoStub()
|
||||
// user 1: 语义一致,可 reused
|
||||
subRepo.seed(&UserSubscription{
|
||||
ID: 21,
|
||||
UserID: 1,
|
||||
GroupID: 1,
|
||||
StartsAt: start,
|
||||
ExpiresAt: start.AddDate(0, 0, 30),
|
||||
Notes: "same-note",
|
||||
})
|
||||
// user 3: 语义冲突(有效期不一致),应 failed
|
||||
subRepo.seed(&UserSubscription{
|
||||
ID: 23,
|
||||
UserID: 3,
|
||||
GroupID: 1,
|
||||
StartsAt: start,
|
||||
ExpiresAt: start.AddDate(0, 0, 60),
|
||||
Notes: "same-note",
|
||||
})
|
||||
|
||||
svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
|
||||
result, err := svc.BulkAssignSubscription(context.Background(), &BulkAssignSubscriptionInput{
|
||||
UserIDs: []int64{1, 2, 3},
|
||||
GroupID: 1,
|
||||
ValidityDays: 30,
|
||||
AssignedBy: 9,
|
||||
Notes: "same-note",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, result.SuccessCount)
|
||||
require.Equal(t, 1, result.CreatedCount)
|
||||
require.Equal(t, 1, result.ReusedCount)
|
||||
require.Equal(t, 1, result.FailedCount)
|
||||
require.Equal(t, "reused", result.Statuses[1])
|
||||
require.Equal(t, "created", result.Statuses[2])
|
||||
require.Equal(t, "failed", result.Statuses[3])
|
||||
require.Equal(t, 1, subRepo.createCalls)
|
||||
}
|
||||
|
||||
func TestAssignSubscriptionKeepsWorkingWhenIdempotencyStoreUnavailable(t *testing.T) {
|
||||
groupRepo := &subscriptionGroupRepoStub{
|
||||
group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
|
||||
}
|
||||
subRepo := newSubscriptionUserSubRepoStub()
|
||||
SetDefaultIdempotencyCoordinator(NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
|
||||
sub, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
|
||||
UserID: 9001,
|
||||
GroupID: 1,
|
||||
ValidityDays: 30,
|
||||
Notes: "new",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sub)
|
||||
require.Equal(t, 1, subRepo.createCalls, "semantic idempotent endpoint should not depend on idempotency store availability")
|
||||
}
|
||||
|
||||
func TestNormalizeAssignValidityDays(t *testing.T) {
|
||||
require.Equal(t, 30, normalizeAssignValidityDays(0))
|
||||
require.Equal(t, 30, normalizeAssignValidityDays(-5))
|
||||
require.Equal(t, MaxValidityDays, normalizeAssignValidityDays(MaxValidityDays+100))
|
||||
require.Equal(t, 7, normalizeAssignValidityDays(7))
|
||||
}
|
||||
|
||||
func TestDetectAssignSemanticConflictCases(t *testing.T) {
|
||||
start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
|
||||
base := &UserSubscription{
|
||||
UserID: 1,
|
||||
GroupID: 1,
|
||||
StartsAt: start,
|
||||
ExpiresAt: start.AddDate(0, 0, 30),
|
||||
Notes: "same",
|
||||
}
|
||||
|
||||
reason, conflict := detectAssignSemanticConflict(base, &AssignSubscriptionInput{
|
||||
UserID: 1,
|
||||
GroupID: 1,
|
||||
ValidityDays: 30,
|
||||
Notes: "same",
|
||||
})
|
||||
require.False(t, conflict)
|
||||
require.Equal(t, "", reason)
|
||||
|
||||
reason, conflict = detectAssignSemanticConflict(base, &AssignSubscriptionInput{
|
||||
UserID: 1,
|
||||
GroupID: 1,
|
||||
ValidityDays: 60,
|
||||
Notes: "same",
|
||||
})
|
||||
require.True(t, conflict)
|
||||
require.Equal(t, "validity_days_mismatch", reason)
|
||||
|
||||
reason, conflict = detectAssignSemanticConflict(base, &AssignSubscriptionInput{
|
||||
UserID: 1,
|
||||
GroupID: 1,
|
||||
ValidityDays: 30,
|
||||
Notes: "other",
|
||||
})
|
||||
require.True(t, conflict)
|
||||
require.Equal(t, "notes_mismatch", reason)
|
||||
}
|
||||
|
||||
func TestAssignSubscriptionGroupTypeValidation(t *testing.T) {
|
||||
groupRepo := &subscriptionGroupRepoStub{
|
||||
group: &Group{ID: 1, SubscriptionType: SubscriptionTypeStandard},
|
||||
}
|
||||
subRepo := newSubscriptionUserSubRepoStub()
|
||||
svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
|
||||
|
||||
_, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
|
||||
UserID: 1,
|
||||
GroupID: 1,
|
||||
ValidityDays: 30,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrGroupNotSubscriptionType), infraerrors.Code(err))
|
||||
}
|
||||
|
||||
func strconvFormatInt(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
|
||||
func infraerrorsReason(err error) string {
|
||||
return infraerrors.Reason(err)
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -24,16 +25,17 @@ var MaxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
|
||||
const MaxValidityDays = 36500
|
||||
|
||||
var (
|
||||
ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
|
||||
ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
|
||||
ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
|
||||
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
|
||||
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
|
||||
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
|
||||
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
||||
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
||||
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
|
||||
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
|
||||
ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
|
||||
ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
|
||||
ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
|
||||
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
|
||||
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
|
||||
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
|
||||
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
|
||||
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
||||
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
||||
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
|
||||
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
|
||||
)
|
||||
|
||||
// SubscriptionService 订阅服务
|
||||
@@ -150,40 +152,10 @@ type AssignSubscriptionInput struct {
|
||||
|
||||
// AssignSubscription 分配订阅给用户(不允许重复分配)
|
||||
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
if !group.IsSubscriptionType() {
|
||||
return nil, ErrGroupNotSubscriptionType
|
||||
}
|
||||
|
||||
// 检查是否已存在订阅
|
||||
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
sub, _, err := s.assignSubscriptionWithReuse(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrSubscriptionAlreadyExists
|
||||
}
|
||||
|
||||
sub, err := s.createSubscription(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
@@ -363,9 +335,12 @@ type BulkAssignSubscriptionInput struct {
|
||||
// BulkAssignResult 批量分配结果
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int
|
||||
CreatedCount int
|
||||
ReusedCount int
|
||||
FailedCount int
|
||||
Subscriptions []UserSubscription
|
||||
Errors []string
|
||||
Statuses map[int64]string
|
||||
}
|
||||
|
||||
// BulkAssignSubscription 批量分配订阅
|
||||
@@ -373,10 +348,11 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
|
||||
result := &BulkAssignResult{
|
||||
Subscriptions: make([]UserSubscription, 0),
|
||||
Errors: make([]string, 0),
|
||||
Statuses: make(map[int64]string),
|
||||
}
|
||||
|
||||
for _, userID := range input.UserIDs {
|
||||
sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{
|
||||
sub, reused, err := s.assignSubscriptionWithReuse(ctx, &AssignSubscriptionInput{
|
||||
UserID: userID,
|
||||
GroupID: input.GroupID,
|
||||
ValidityDays: input.ValidityDays,
|
||||
@@ -386,15 +362,105 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
|
||||
if err != nil {
|
||||
result.FailedCount++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
|
||||
result.Statuses[userID] = "failed"
|
||||
} else {
|
||||
result.SuccessCount++
|
||||
result.Subscriptions = append(result.Subscriptions, *sub)
|
||||
if reused {
|
||||
result.ReusedCount++
|
||||
result.Statuses[userID] = "reused"
|
||||
} else {
|
||||
result.CreatedCount++
|
||||
result.Statuses[userID] = "created"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *SubscriptionService) assignSubscriptionWithReuse(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
if !group.IsSubscriptionType() {
|
||||
return nil, false, ErrGroupNotSubscriptionType
|
||||
}
|
||||
|
||||
// 检查是否已存在订阅;若已存在,则按幂等成功返回现有订阅
|
||||
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if exists {
|
||||
sub, getErr := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if getErr != nil {
|
||||
return nil, false, getErr
|
||||
}
|
||||
if conflictReason, conflict := detectAssignSemanticConflict(sub, input); conflict {
|
||||
return nil, false, ErrSubscriptionAssignConflict.WithMetadata(map[string]string{
|
||||
"conflict_reason": conflictReason,
|
||||
})
|
||||
}
|
||||
return sub, true, nil
|
||||
}
|
||||
|
||||
sub, err := s.createSubscription(ctx, input)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
return sub, false, nil
|
||||
}
|
||||
|
||||
func detectAssignSemanticConflict(existing *UserSubscription, input *AssignSubscriptionInput) (string, bool) {
|
||||
if existing == nil || input == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
normalizedDays := normalizeAssignValidityDays(input.ValidityDays)
|
||||
if !existing.StartsAt.IsZero() {
|
||||
expectedExpiresAt := existing.StartsAt.AddDate(0, 0, normalizedDays)
|
||||
if expectedExpiresAt.After(MaxExpiresAt) {
|
||||
expectedExpiresAt = MaxExpiresAt
|
||||
}
|
||||
if !existing.ExpiresAt.Equal(expectedExpiresAt) {
|
||||
return "validity_days_mismatch", true
|
||||
}
|
||||
}
|
||||
|
||||
existingNotes := strings.TrimSpace(existing.Notes)
|
||||
inputNotes := strings.TrimSpace(input.Notes)
|
||||
if existingNotes != inputNotes {
|
||||
return "notes_mismatch", true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func normalizeAssignValidityDays(days int) int {
|
||||
if days <= 0 {
|
||||
days = 30
|
||||
}
|
||||
if days > MaxValidityDays {
|
||||
days = MaxValidityDays
|
||||
}
|
||||
return days
|
||||
}
|
||||
|
||||
// RevokeSubscription 撤销订阅
|
||||
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
|
||||
// 先获取订阅信息用于失效缓存
|
||||
|
||||
214
backend/internal/service/system_operation_lock_service.go
Normal file
214
backend/internal/service/system_operation_lock_service.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
systemOperationLockScope = "admin.system.operations.global_lock"
|
||||
systemOperationLockKey = "global-system-operation-lock"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSystemOperationBusy = infraerrors.Conflict("SYSTEM_OPERATION_BUSY", "another system operation is in progress")
|
||||
)
|
||||
|
||||
type SystemOperationLock struct {
|
||||
recordID int64
|
||||
operationID string
|
||||
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func (l *SystemOperationLock) OperationID() string {
|
||||
if l == nil {
|
||||
return ""
|
||||
}
|
||||
return l.operationID
|
||||
}
|
||||
|
||||
type SystemOperationLockService struct {
|
||||
repo IdempotencyRepository
|
||||
|
||||
lease time.Duration
|
||||
renewInterval time.Duration
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewSystemOperationLockService(repo IdempotencyRepository, cfg IdempotencyConfig) *SystemOperationLockService {
|
||||
lease := cfg.ProcessingTimeout
|
||||
if lease <= 0 {
|
||||
lease = 30 * time.Second
|
||||
}
|
||||
renewInterval := lease / 3
|
||||
if renewInterval < time.Second {
|
||||
renewInterval = time.Second
|
||||
}
|
||||
ttl := cfg.SystemOperationTTL
|
||||
if ttl <= 0 {
|
||||
ttl = time.Hour
|
||||
}
|
||||
|
||||
return &SystemOperationLockService{
|
||||
repo: repo,
|
||||
lease: lease,
|
||||
renewInterval: renewInterval,
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SystemOperationLockService) Acquire(ctx context.Context, operationID string) (*SystemOperationLock, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, ErrIdempotencyStoreUnavail
|
||||
}
|
||||
if operationID == "" {
|
||||
return nil, infraerrors.BadRequest("SYSTEM_OPERATION_ID_REQUIRED", "operation id is required")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(s.ttl)
|
||||
lockedUntil := now.Add(s.lease)
|
||||
keyHash := HashIdempotencyKey(systemOperationLockKey)
|
||||
|
||||
record := &IdempotencyRecord{
|
||||
Scope: systemOperationLockScope,
|
||||
IdempotencyKeyHash: keyHash,
|
||||
RequestFingerprint: operationID,
|
||||
Status: IdempotencyStatusProcessing,
|
||||
LockedUntil: &lockedUntil,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
owner, err := s.repo.CreateProcessing(ctx, record)
|
||||
if err != nil {
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(err)
|
||||
}
|
||||
if !owner {
|
||||
existing, getErr := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash)
|
||||
if getErr != nil {
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(getErr)
|
||||
}
|
||||
if existing == nil {
|
||||
return nil, ErrIdempotencyStoreUnavail
|
||||
}
|
||||
if existing.Status == IdempotencyStatusProcessing && existing.LockedUntil != nil && existing.LockedUntil.After(now) {
|
||||
return nil, s.busyError(existing.RequestFingerprint, existing.LockedUntil, now)
|
||||
}
|
||||
reclaimed, reclaimErr := s.repo.TryReclaim(
|
||||
ctx,
|
||||
existing.ID,
|
||||
existing.Status,
|
||||
now,
|
||||
lockedUntil,
|
||||
expiresAt,
|
||||
)
|
||||
if reclaimErr != nil {
|
||||
return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
|
||||
}
|
||||
if !reclaimed {
|
||||
latest, _ := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash)
|
||||
if latest != nil {
|
||||
return nil, s.busyError(latest.RequestFingerprint, latest.LockedUntil, now)
|
||||
}
|
||||
return nil, ErrSystemOperationBusy
|
||||
}
|
||||
record.ID = existing.ID
|
||||
}
|
||||
|
||||
if record.ID == 0 {
|
||||
return nil, ErrIdempotencyStoreUnavail
|
||||
}
|
||||
|
||||
lock := &SystemOperationLock{
|
||||
recordID: record.ID,
|
||||
operationID: operationID,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go s.renewLoop(lock)
|
||||
|
||||
return lock, nil
|
||||
}
|
||||
|
||||
func (s *SystemOperationLockService) Release(ctx context.Context, lock *SystemOperationLock, succeeded bool, failureReason string) error {
|
||||
if s == nil || s.repo == nil || lock == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
lock.stopOnce.Do(func() {
|
||||
close(lock.stopCh)
|
||||
})
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(s.ttl)
|
||||
if succeeded {
|
||||
responseBody := fmt.Sprintf(`{"operation_id":"%s","released":true}`, lock.operationID)
|
||||
return s.repo.MarkSucceeded(ctx, lock.recordID, 200, responseBody, expiresAt)
|
||||
}
|
||||
|
||||
reason := failureReason
|
||||
if reason == "" {
|
||||
reason = "SYSTEM_OPERATION_FAILED"
|
||||
}
|
||||
return s.repo.MarkFailedRetryable(ctx, lock.recordID, reason, time.Now(), expiresAt)
|
||||
}
|
||||
|
||||
func (s *SystemOperationLockService) renewLoop(lock *SystemOperationLock) {
|
||||
ticker := time.NewTicker(s.renewInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ok, err := s.repo.ExtendProcessingLock(
|
||||
ctx,
|
||||
lock.recordID,
|
||||
lock.operationID,
|
||||
now.Add(s.lease),
|
||||
now.Add(s.ttl),
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew failed operation_id=%s err=%v", lock.operationID, err)
|
||||
// 瞬时故障不应导致续租协程退出,下一轮继续尝试续租。
|
||||
continue
|
||||
}
|
||||
if !ok {
|
||||
logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew stopped operation_id=%s reason=ownership_lost", lock.operationID)
|
||||
return
|
||||
}
|
||||
case <-lock.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SystemOperationLockService) busyError(operationID string, lockedUntil *time.Time, now time.Time) error {
|
||||
metadata := make(map[string]string)
|
||||
if operationID != "" {
|
||||
metadata["operation_id"] = operationID
|
||||
}
|
||||
if lockedUntil != nil {
|
||||
sec := int(lockedUntil.Sub(now).Seconds())
|
||||
if sec <= 0 {
|
||||
sec = 1
|
||||
}
|
||||
metadata["retry_after"] = strconv.Itoa(sec)
|
||||
}
|
||||
if len(metadata) == 0 {
|
||||
return ErrSystemOperationBusy
|
||||
}
|
||||
return ErrSystemOperationBusy.WithMetadata(metadata)
|
||||
}
|
||||
305
backend/internal/service/system_operation_lock_service_test.go
Normal file
305
backend/internal/service/system_operation_lock_service_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSystemOperationLockService_AcquireBusyAndRelease(t *testing.T) {
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
svc := NewSystemOperationLockService(repo, IdempotencyConfig{
|
||||
SystemOperationTTL: 10 * time.Second,
|
||||
ProcessingTimeout: 2 * time.Second,
|
||||
})
|
||||
|
||||
lock1, err := svc.Acquire(context.Background(), "op-1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lock1)
|
||||
|
||||
_, err = svc.Acquire(context.Background(), "op-2")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
|
||||
appErr := infraerrors.FromError(err)
|
||||
require.Equal(t, "op-1", appErr.Metadata["operation_id"])
|
||||
require.NotEmpty(t, appErr.Metadata["retry_after"])
|
||||
|
||||
require.NoError(t, svc.Release(context.Background(), lock1, true, ""))
|
||||
|
||||
lock2, err := svc.Acquire(context.Background(), "op-2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lock2)
|
||||
require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
|
||||
}
|
||||
|
||||
func TestSystemOperationLockService_RenewLease(t *testing.T) {
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
svc := NewSystemOperationLockService(repo, IdempotencyConfig{
|
||||
SystemOperationTTL: 5 * time.Second,
|
||||
ProcessingTimeout: 1200 * time.Millisecond,
|
||||
})
|
||||
|
||||
lock, err := svc.Acquire(context.Background(), "op-renew")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lock)
|
||||
defer func() {
|
||||
_ = svc.Release(context.Background(), lock, true, "")
|
||||
}()
|
||||
|
||||
keyHash := HashIdempotencyKey(systemOperationLockKey)
|
||||
initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
|
||||
require.NotNil(t, initial)
|
||||
require.NotNil(t, initial.LockedUntil)
|
||||
initialLockedUntil := *initial.LockedUntil
|
||||
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
|
||||
require.NotNil(t, updated)
|
||||
require.NotNil(t, updated.LockedUntil)
|
||||
require.True(t, updated.LockedUntil.After(initialLockedUntil), "locked_until should be renewed while lock is held")
|
||||
}
|
||||
|
||||
type flakySystemLockRenewRepo struct {
|
||||
*inMemoryIdempotencyRepo
|
||||
extendCalls int32
|
||||
}
|
||||
|
||||
func (r *flakySystemLockRenewRepo) ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
call := atomic.AddInt32(&r.extendCalls, 1)
|
||||
if call == 1 {
|
||||
return false, errors.New("transient extend failure")
|
||||
}
|
||||
return r.inMemoryIdempotencyRepo.ExtendProcessingLock(ctx, id, requestFingerprint, newLockedUntil, newExpiresAt)
|
||||
}
|
||||
|
||||
func TestSystemOperationLockService_RenewLeaseContinuesAfterTransientFailure(t *testing.T) {
|
||||
repo := &flakySystemLockRenewRepo{inMemoryIdempotencyRepo: newInMemoryIdempotencyRepo()}
|
||||
svc := NewSystemOperationLockService(repo, IdempotencyConfig{
|
||||
SystemOperationTTL: 5 * time.Second,
|
||||
ProcessingTimeout: 2400 * time.Millisecond,
|
||||
})
|
||||
|
||||
lock, err := svc.Acquire(context.Background(), "op-renew-transient")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lock)
|
||||
defer func() {
|
||||
_ = svc.Release(context.Background(), lock, true, "")
|
||||
}()
|
||||
|
||||
keyHash := HashIdempotencyKey(systemOperationLockKey)
|
||||
initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
|
||||
require.NotNil(t, initial)
|
||||
require.NotNil(t, initial.LockedUntil)
|
||||
initialLockedUntil := *initial.LockedUntil
|
||||
|
||||
// 首次续租失败后,下一轮应继续尝试并成功更新锁过期时间。
|
||||
require.Eventually(t, func() bool {
|
||||
updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
|
||||
if updated == nil || updated.LockedUntil == nil {
|
||||
return false
|
||||
}
|
||||
return atomic.LoadInt32(&repo.extendCalls) >= 2 && updated.LockedUntil.After(initialLockedUntil)
|
||||
}, 4*time.Second, 100*time.Millisecond, "renew loop should continue after transient error")
|
||||
}
|
||||
|
||||
func TestSystemOperationLockService_SameOperationIDRetryWhileRunning(t *testing.T) {
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
svc := NewSystemOperationLockService(repo, IdempotencyConfig{
|
||||
SystemOperationTTL: 10 * time.Second,
|
||||
ProcessingTimeout: 2 * time.Second,
|
||||
})
|
||||
|
||||
lock1, err := svc.Acquire(context.Background(), "op-same")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lock1)
|
||||
|
||||
_, err = svc.Acquire(context.Background(), "op-same")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
|
||||
appErr := infraerrors.FromError(err)
|
||||
require.Equal(t, "op-same", appErr.Metadata["operation_id"])
|
||||
|
||||
require.NoError(t, svc.Release(context.Background(), lock1, true, ""))
|
||||
|
||||
lock2, err := svc.Acquire(context.Background(), "op-same")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lock2)
|
||||
require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
|
||||
}
|
||||
|
||||
func TestSystemOperationLockService_RecoverAfterLeaseExpired(t *testing.T) {
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
svc := NewSystemOperationLockService(repo, IdempotencyConfig{
|
||||
SystemOperationTTL: 5 * time.Second,
|
||||
ProcessingTimeout: 300 * time.Millisecond,
|
||||
})
|
||||
|
||||
lock1, err := svc.Acquire(context.Background(), "op-crashed")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lock1)
|
||||
|
||||
// 模拟实例异常:停止续租,不调用 Release。
|
||||
lock1.stopOnce.Do(func() {
|
||||
close(lock1.stopCh)
|
||||
})
|
||||
|
||||
time.Sleep(450 * time.Millisecond)
|
||||
|
||||
lock2, err := svc.Acquire(context.Background(), "op-recovered")
|
||||
require.NoError(t, err, "expired lease should allow a new operation to reclaim lock")
|
||||
require.NotNil(t, lock2)
|
||||
require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
|
||||
}
|
||||
|
||||
type systemLockRepoStub struct {
|
||||
createOwner bool
|
||||
createErr error
|
||||
existing *IdempotencyRecord
|
||||
getErr error
|
||||
reclaimOK bool
|
||||
reclaimErr error
|
||||
markSuccErr error
|
||||
markFailErr error
|
||||
}
|
||||
|
||||
func (s *systemLockRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
|
||||
if s.createErr != nil {
|
||||
return false, s.createErr
|
||||
}
|
||||
return s.createOwner, nil
|
||||
}
|
||||
|
||||
func (s *systemLockRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return cloneRecord(s.existing), nil
|
||||
}
|
||||
|
||||
func (s *systemLockRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
if s.reclaimErr != nil {
|
||||
return false, s.reclaimErr
|
||||
}
|
||||
return s.reclaimOK, nil
|
||||
}
|
||||
|
||||
func (s *systemLockRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *systemLockRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return s.markSuccErr
|
||||
}
|
||||
|
||||
func (s *systemLockRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return s.markFailErr
|
||||
}
|
||||
|
||||
func (s *systemLockRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestSystemOperationLockService_InputAndStoreErrorBranches(t *testing.T) {
|
||||
var nilSvc *SystemOperationLockService
|
||||
_, err := nilSvc.Acquire(context.Background(), "x")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
|
||||
svc := &SystemOperationLockService{repo: nil}
|
||||
_, err = svc.Acquire(context.Background(), "x")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
|
||||
svc = NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{
|
||||
SystemOperationTTL: 10 * time.Second,
|
||||
ProcessingTimeout: 2 * time.Second,
|
||||
})
|
||||
_, err = svc.Acquire(context.Background(), "")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "SYSTEM_OPERATION_ID_REQUIRED", infraerrors.Reason(err))
|
||||
|
||||
badStore := &systemLockRepoStub{createErr: errors.New("db down")}
|
||||
svc = NewSystemOperationLockService(badStore, IdempotencyConfig{
|
||||
SystemOperationTTL: 10 * time.Second,
|
||||
ProcessingTimeout: 2 * time.Second,
|
||||
})
|
||||
_, err = svc.Acquire(context.Background(), "x")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
}
|
||||
|
||||
func TestSystemOperationLockService_ExistingNilAndReclaimBranches(t *testing.T) {
|
||||
now := time.Now()
|
||||
repo := &systemLockRepoStub{
|
||||
createOwner: false,
|
||||
}
|
||||
svc := NewSystemOperationLockService(repo, IdempotencyConfig{
|
||||
SystemOperationTTL: 10 * time.Second,
|
||||
ProcessingTimeout: 2 * time.Second,
|
||||
})
|
||||
|
||||
_, err := svc.Acquire(context.Background(), "op")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
|
||||
repo.existing = &IdempotencyRecord{
|
||||
ID: 1,
|
||||
Scope: systemOperationLockScope,
|
||||
IdempotencyKeyHash: HashIdempotencyKey(systemOperationLockKey),
|
||||
RequestFingerprint: "other-op",
|
||||
Status: IdempotencyStatusFailedRetryable,
|
||||
LockedUntil: ptrTime(now.Add(-time.Second)),
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
}
|
||||
repo.reclaimErr = errors.New("reclaim failed")
|
||||
_, err = svc.Acquire(context.Background(), "op")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
|
||||
repo.reclaimErr = nil
|
||||
repo.reclaimOK = false
|
||||
_, err = svc.Acquire(context.Background(), "op")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
|
||||
}
|
||||
|
||||
func TestSystemOperationLockService_ReleaseBranchesAndOperationID(t *testing.T) {
|
||||
require.Equal(t, "", (*SystemOperationLock)(nil).OperationID())
|
||||
|
||||
svc := NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{
|
||||
SystemOperationTTL: 10 * time.Second,
|
||||
ProcessingTimeout: 2 * time.Second,
|
||||
})
|
||||
lock, err := svc.Acquire(context.Background(), "op")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lock)
|
||||
|
||||
require.NoError(t, svc.Release(context.Background(), lock, false, ""))
|
||||
require.NoError(t, svc.Release(context.Background(), lock, true, ""))
|
||||
|
||||
repo := &systemLockRepoStub{
|
||||
createOwner: true,
|
||||
markSuccErr: errors.New("mark succeeded failed"),
|
||||
markFailErr: errors.New("mark failed failed"),
|
||||
}
|
||||
svc = NewSystemOperationLockService(repo, IdempotencyConfig{
|
||||
SystemOperationTTL: 10 * time.Second,
|
||||
ProcessingTimeout: 2 * time.Second,
|
||||
})
|
||||
lock = &SystemOperationLock{recordID: 1, operationID: "op2", stopCh: make(chan struct{})}
|
||||
require.Error(t, svc.Release(context.Background(), lock, true, ""))
|
||||
lock = &SystemOperationLock{recordID: 1, operationID: "op3", stopCh: make(chan struct{})}
|
||||
require.Error(t, svc.Release(context.Background(), lock, false, "BAD"))
|
||||
|
||||
var nilLockSvc *SystemOperationLockService
|
||||
require.NoError(t, nilLockSvc.Release(context.Background(), nil, true, ""))
|
||||
|
||||
err = svc.busyError("", nil, time.Now())
|
||||
require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
|
||||
}
|
||||
@@ -320,6 +320,10 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc
|
||||
return err
|
||||
}
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
|
||||
if status == UsageCleanupStatusCanceled {
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent hit: task=%d operator=%d", taskID, canceledBy)
|
||||
return nil
|
||||
}
|
||||
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
|
||||
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
|
||||
}
|
||||
@@ -329,6 +333,11 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc
|
||||
}
|
||||
if !ok {
|
||||
// 状态可能并发改变
|
||||
currentStatus, getErr := s.repo.GetTaskStatus(ctx, taskID)
|
||||
if getErr == nil && currentStatus == UsageCleanupStatusCanceled {
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent race hit: task=%d operator=%d", taskID, canceledBy)
|
||||
return nil
|
||||
}
|
||||
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
|
||||
}
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)
|
||||
|
||||
@@ -644,6 +644,23 @@ func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) {
|
||||
require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskAlreadyCanceledIsIdempotent(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
7: UsageCleanupStatusCanceled,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Empty(t, repo.cancelCalls, "already canceled should return success without extra cancel write")
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) {
|
||||
shouldCancel := false
|
||||
repo := &cleanupRepoStub{
|
||||
|
||||
@@ -225,6 +225,45 @@ func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Confi
|
||||
return svc
|
||||
}
|
||||
|
||||
func buildIdempotencyConfig(cfg *config.Config) IdempotencyConfig {
|
||||
idempotencyCfg := DefaultIdempotencyConfig()
|
||||
if cfg != nil {
|
||||
if cfg.Idempotency.DefaultTTLSeconds > 0 {
|
||||
idempotencyCfg.DefaultTTL = time.Duration(cfg.Idempotency.DefaultTTLSeconds) * time.Second
|
||||
}
|
||||
if cfg.Idempotency.SystemOperationTTLSeconds > 0 {
|
||||
idempotencyCfg.SystemOperationTTL = time.Duration(cfg.Idempotency.SystemOperationTTLSeconds) * time.Second
|
||||
}
|
||||
if cfg.Idempotency.ProcessingTimeoutSeconds > 0 {
|
||||
idempotencyCfg.ProcessingTimeout = time.Duration(cfg.Idempotency.ProcessingTimeoutSeconds) * time.Second
|
||||
}
|
||||
if cfg.Idempotency.FailedRetryBackoffSeconds > 0 {
|
||||
idempotencyCfg.FailedRetryBackoff = time.Duration(cfg.Idempotency.FailedRetryBackoffSeconds) * time.Second
|
||||
}
|
||||
if cfg.Idempotency.MaxStoredResponseLen > 0 {
|
||||
idempotencyCfg.MaxStoredResponseLen = cfg.Idempotency.MaxStoredResponseLen
|
||||
}
|
||||
idempotencyCfg.ObserveOnly = cfg.Idempotency.ObserveOnly
|
||||
}
|
||||
return idempotencyCfg
|
||||
}
|
||||
|
||||
func ProvideIdempotencyCoordinator(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCoordinator {
|
||||
coordinator := NewIdempotencyCoordinator(repo, buildIdempotencyConfig(cfg))
|
||||
SetDefaultIdempotencyCoordinator(coordinator)
|
||||
return coordinator
|
||||
}
|
||||
|
||||
func ProvideSystemOperationLockService(repo IdempotencyRepository, cfg *config.Config) *SystemOperationLockService {
|
||||
return NewSystemOperationLockService(repo, buildIdempotencyConfig(cfg))
|
||||
}
|
||||
|
||||
func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService {
|
||||
svc := NewIdempotencyCleanupService(repo, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
|
||||
func ProvideOpsScheduledReportService(
|
||||
opsService *OpsService,
|
||||
@@ -318,4 +357,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewTotpService,
|
||||
NewErrorPassthroughService,
|
||||
NewDigestSessionStore,
|
||||
ProvideIdempotencyCoordinator,
|
||||
ProvideSystemOperationLockService,
|
||||
ProvideIdempotencyCleanupService,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user