feat(idempotency): add idempotency to critical write endpoints and harden concurrency fault tolerance
@@ -35,6 +35,8 @@ var (
 const (
 	apiKeyMaxErrorsPerHour = 20
 	apiKeyLastUsedMinTouch = 30 * time.Second
+	// Short backoff after a failed DB write, so the request path does not keep retrying synchronously and cause a write storm and high latency.
+	apiKeyLastUsedFailBackoff = 5 * time.Second
 )

 type APIKeyRepository interface {
@@ -129,7 +131,7 @@ type APIKeyService struct {
 	authCacheL1 *ristretto.Cache
 	authCfg     apiKeyAuthCacheConfig
 	authGroup   singleflight.Group
-	lastUsedTouchL1 sync.Map // keyID -> time.Time
+	lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
 	lastUsedTouchSF singleflight.Group
 }

@@ -574,7 +576,7 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {

 	now := time.Now()
 	if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
-		if last, ok := v.(time.Time); ok && now.Sub(last) < apiKeyLastUsedMinTouch {
+		if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
 			return nil
 		}
 	}
@@ -582,15 +584,16 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
 	_, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) {
 		latest := time.Now()
 		if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
-			if last, ok := v.(time.Time); ok && latest.Sub(last) < apiKeyLastUsedMinTouch {
+			if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
 				return nil, nil
 			}
 		}

 		if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil {
+			s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff))
 			return nil, fmt.Errorf("touch api key last used: %w", err)
 		}
-		s.lastUsedTouchL1.Store(keyID, latest)
+		s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch))
 		return nil, nil
 	})
 	return err
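The cache now stores the next allowed touch time rather than the last touch time, which lets success and failure install different windows (the full 30s window after a successful write, a short 5s backoff after a failed one). A minimal self-contained sketch of the pattern — names here are illustrative, not part of the service, and the usual sync/time imports are assumed:

// touchDebounced is a distilled sketch of the nextAllowedAt debounce above.
func touchDebounced(m *sync.Map, keyID int64, write func() error, okWindow, failBackoff time.Duration) error {
	now := time.Now()
	if v, ok := m.Load(keyID); ok {
		if next, ok := v.(time.Time); ok && now.Before(next) {
			return nil // still inside the debounce window: skip the write entirely
		}
	}
	if err := write(); err != nil {
		m.Store(keyID, now.Add(failBackoff)) // short window after failure
		return err
	}
	m.Store(keyID, now.Add(okWindow)) // full window after success
	return nil
}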
@@ -79,8 +79,27 @@ func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) {
 	require.ErrorContains(t, err, "touch api key last used")
 	require.Equal(t, []int64{123}, repo.touchedIDs)

-	_, ok := svc.lastUsedTouchL1.Load(int64(123))
-	require.False(t, ok, "failed touch should not update debounce cache")
+	cached, ok := svc.lastUsedTouchL1.Load(int64(123))
+	require.True(t, ok, "failed touch should still update retry debounce cache")
+	_, isTime := cached.(time.Time)
+	require.True(t, isTime)
 }

+func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) {
+	repo := &apiKeyRepoStub{
+		updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
+			return errors.New("db write failed")
+		},
+	}
+	svc := &APIKeyService{apiKeyRepo: repo}
+
+	firstErr := svc.TouchLastUsed(context.Background(), 456)
+	require.Error(t, firstErr)
+	require.ErrorContains(t, firstErr, "touch api key last used")
+
+	secondErr := svc.TouchLastUsed(context.Background(), 456)
+	require.NoError(t, secondErr, "failed touch should be debounced and skip immediate retry")
+	require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again")
+}
+
 type touchSingleflightRepo struct {
backend/internal/service/idempotency.go (new file, 471 lines)
@@ -0,0 +1,471 @@
package service

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"time"

	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
	"github.com/Wei-Shaw/sub2api/internal/util/logredact"
)

const (
	IdempotencyStatusProcessing      = "processing"
	IdempotencyStatusSucceeded       = "succeeded"
	IdempotencyStatusFailedRetryable = "failed_retryable"
)

var (
	ErrIdempotencyKeyRequired    = infraerrors.BadRequest("IDEMPOTENCY_KEY_REQUIRED", "idempotency key is required")
	ErrIdempotencyKeyInvalid     = infraerrors.BadRequest("IDEMPOTENCY_KEY_INVALID", "idempotency key is invalid")
	ErrIdempotencyKeyConflict    = infraerrors.Conflict("IDEMPOTENCY_KEY_CONFLICT", "idempotency key reused with different payload")
	ErrIdempotencyInProgress     = infraerrors.Conflict("IDEMPOTENCY_IN_PROGRESS", "idempotent request is still processing")
	ErrIdempotencyRetryBackoff   = infraerrors.Conflict("IDEMPOTENCY_RETRY_BACKOFF", "idempotent request is in retry backoff window")
	ErrIdempotencyStoreUnavail   = infraerrors.ServiceUnavailable("IDEMPOTENCY_STORE_UNAVAILABLE", "idempotency store unavailable")
	ErrIdempotencyInvalidPayload = infraerrors.BadRequest("IDEMPOTENCY_PAYLOAD_INVALID", "failed to normalize request payload")
)

type IdempotencyRecord struct {
	ID                 int64
	Scope              string
	IdempotencyKeyHash string
	RequestFingerprint string
	Status             string
	ResponseStatus     *int
	ResponseBody       *string
	ErrorReason        *string
	LockedUntil        *time.Time
	ExpiresAt          time.Time
	CreatedAt          time.Time
	UpdatedAt          time.Time
}

type IdempotencyRepository interface {
	CreateProcessing(ctx context.Context, record *IdempotencyRecord) (bool, error)
	GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*IdempotencyRecord, error)
	TryReclaim(ctx context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error)
	ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error)
	MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error
	MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error
	DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error)
}

type IdempotencyConfig struct {
	DefaultTTL           time.Duration
	SystemOperationTTL   time.Duration
	ProcessingTimeout    time.Duration
	FailedRetryBackoff   time.Duration
	MaxStoredResponseLen int
	ObserveOnly          bool
}

func DefaultIdempotencyConfig() IdempotencyConfig {
	return IdempotencyConfig{
		DefaultTTL:           24 * time.Hour,
		SystemOperationTTL:   1 * time.Hour,
		ProcessingTimeout:    30 * time.Second,
		FailedRetryBackoff:   5 * time.Second,
		MaxStoredResponseLen: 64 * 1024,
		ObserveOnly:          true, // observe first, enforce later, so old clients are not cut off immediately
	}
}
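With ObserveOnly left at its default, a request without a key never gets rejected even when RequireKey is set; flipping it to false turns RequireKey into a hard error. A hedged wiring sketch of the staged rollout — repo stands for whatever IdempotencyRepository implementation the app provides:

cfg := DefaultIdempotencyConfig() // ObserveOnly: true — keyless writes still pass through
cfg.ObserveOnly = false           // enforce once clients reliably send Idempotency-Key
SetDefaultIdempotencyCoordinator(NewIdempotencyCoordinator(repo, cfg))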
type IdempotencyExecuteOptions struct {
	Scope          string
	ActorScope     string
	Method         string
	Route          string
	IdempotencyKey string
	Payload        any
	TTL            time.Duration
	RequireKey     bool
}

type IdempotencyExecuteResult struct {
	Data     any
	Replayed bool
}

type IdempotencyCoordinator struct {
	repo IdempotencyRepository
	cfg  IdempotencyConfig
}

var (
	defaultIdempotencyMu  sync.RWMutex
	defaultIdempotencySvc *IdempotencyCoordinator
)

func SetDefaultIdempotencyCoordinator(svc *IdempotencyCoordinator) {
	defaultIdempotencyMu.Lock()
	defaultIdempotencySvc = svc
	defaultIdempotencyMu.Unlock()
}

func DefaultIdempotencyCoordinator() *IdempotencyCoordinator {
	defaultIdempotencyMu.RLock()
	defer defaultIdempotencyMu.RUnlock()
	return defaultIdempotencySvc
}

func DefaultWriteIdempotencyTTL() time.Duration {
	defaultTTL := DefaultIdempotencyConfig().DefaultTTL
	if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.DefaultTTL > 0 {
		return coordinator.cfg.DefaultTTL
	}
	return defaultTTL
}

func DefaultSystemOperationIdempotencyTTL() time.Duration {
	defaultTTL := DefaultIdempotencyConfig().SystemOperationTTL
	if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.SystemOperationTTL > 0 {
		return coordinator.cfg.SystemOperationTTL
	}
	return defaultTTL
}

func NewIdempotencyCoordinator(repo IdempotencyRepository, cfg IdempotencyConfig) *IdempotencyCoordinator {
	return &IdempotencyCoordinator{
		repo: repo,
		cfg:  cfg,
	}
}

func NormalizeIdempotencyKey(raw string) (string, error) {
	key := strings.TrimSpace(raw)
	if key == "" {
		return "", nil
	}
	if len(key) > 128 {
		return "", ErrIdempotencyKeyInvalid
	}
	for _, r := range key {
		if r < 33 || r > 126 {
			return "", ErrIdempotencyKeyInvalid
		}
	}
	return key, nil
}

func HashIdempotencyKey(key string) string {
	sum := sha256.Sum256([]byte(key))
	return hex.EncodeToString(sum[:])
}

func BuildIdempotencyFingerprint(method, route, actorScope string, payload any) (string, error) {
	if method == "" {
		method = "POST"
	}
	if route == "" {
		route = "/"
	}
	if actorScope == "" {
		actorScope = "anonymous"
	}

	raw, err := json.Marshal(payload)
	if err != nil {
		return "", ErrIdempotencyInvalidPayload.WithCause(err)
	}
	sum := sha256.Sum256([]byte(
		strings.ToUpper(method) + "\n" + route + "\n" + actorScope + "\n" + string(raw),
	))
	return hex.EncodeToString(sum[:]), nil
}
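The fingerprint binds method, route, actor, and payload together, so reusing a key with a changed body is detectable as a conflict. Payload hashing is deterministic for map payloads because encoding/json sorts map keys, and the method is upper-cased before hashing. A small worked example (values illustrative):

fp1, _ := BuildIdempotencyFingerprint("post", "/orders", "user:1", map[string]any{"b": 2, "a": 1})
fp2, _ := BuildIdempotencyFingerprint("POST", "/orders", "user:1", map[string]any{"a": 1, "b": 2})
// fp1 == fp2: ToUpper normalizes the method, json.Marshal sorts map keys.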
func RetryAfterSecondsFromError(err error) int {
	appErr := new(infraerrors.ApplicationError)
	if !errors.As(err, &appErr) || appErr == nil || appErr.Metadata == nil {
		return 0
	}
	v := strings.TrimSpace(appErr.Metadata["retry_after"])
	if v == "" {
		return 0
	}
	seconds, convErr := strconv.Atoi(v)
	if convErr != nil || seconds <= 0 {
		return 0
	}
	return seconds
}

func (c *IdempotencyCoordinator) Execute(
	ctx context.Context,
	opts IdempotencyExecuteOptions,
	execute func(context.Context) (any, error),
) (*IdempotencyExecuteResult, error) {
	if execute == nil {
		return nil, infraerrors.InternalServer("IDEMPOTENCY_EXECUTOR_NIL", "idempotency executor is nil")
	}

	key, err := NormalizeIdempotencyKey(opts.IdempotencyKey)
	if err != nil {
		return nil, err
	}
	if key == "" {
		if opts.RequireKey && !c.cfg.ObserveOnly {
			return nil, ErrIdempotencyKeyRequired
		}
		data, execErr := execute(ctx)
		if execErr != nil {
			return nil, execErr
		}
		return &IdempotencyExecuteResult{Data: data}, nil
	}
	if c.repo == nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "repo_nil")
		return nil, ErrIdempotencyStoreUnavail
	}

	if opts.Scope == "" {
		return nil, infraerrors.BadRequest("IDEMPOTENCY_SCOPE_REQUIRED", "idempotency scope is required")
	}

	fingerprint, err := BuildIdempotencyFingerprint(opts.Method, opts.Route, opts.ActorScope, opts.Payload)
	if err != nil {
		return nil, err
	}

	ttl := opts.TTL
	if ttl <= 0 {
		ttl = c.cfg.DefaultTTL
	}
	now := time.Now()
	expiresAt := now.Add(ttl)
	lockedUntil := now.Add(c.cfg.ProcessingTimeout)
	keyHash := HashIdempotencyKey(key)

	record := &IdempotencyRecord{
		Scope:              opts.Scope,
		IdempotencyKeyHash: keyHash,
		RequestFingerprint: fingerprint,
		Status:             IdempotencyStatusProcessing,
		LockedUntil:        &lockedUntil,
		ExpiresAt:          expiresAt,
	}

	owner, err := c.repo.CreateProcessing(ctx, record)
	if err != nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "create_processing_error")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
			"operation": "create_processing",
		})
		return nil, ErrIdempotencyStoreUnavail.WithCause(err)
	}
	if owner {
		recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "new_claim"})
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "none->processing", false, map[string]string{
			"claim_mode": "new",
		})
	}
	if !owner {
		existing, getErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash)
		if getErr != nil {
			RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_error")
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
				"operation": "get_existing",
			})
			return nil, ErrIdempotencyStoreUnavail.WithCause(getErr)
		}
		if existing == nil {
			RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing")
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
				"operation": "missing_existing",
			})
			return nil, ErrIdempotencyStoreUnavail
		}
		if existing.RequestFingerprint != fingerprint {
			recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"})
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil)
			return nil, ErrIdempotencyKeyConflict
		}
		reclaimedByExpired := false
		if !existing.ExpiresAt.After(now) {
			taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, existing.Status, now, lockedUntil, expiresAt)
			if reclaimErr != nil {
				RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_expired_error")
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->store_unavailable", false, map[string]string{
					"operation": "try_reclaim_expired",
				})
				return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
			}
			if taken {
				reclaimedByExpired = true
				recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "expired_reclaim"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->processing", false, map[string]string{
					"claim_mode": "expired_reclaim",
				})
				record.ID = existing.ID
			} else {
				latest, latestErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash)
				if latestErr != nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_after_expired_reclaim_error")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
						"operation": "get_existing_after_expired_reclaim",
					})
					return nil, ErrIdempotencyStoreUnavail.WithCause(latestErr)
				}
				if latest == nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing_after_expired_reclaim")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
						"operation": "missing_existing_after_expired_reclaim",
					})
					return nil, ErrIdempotencyStoreUnavail
				}
				if latest.RequestFingerprint != fingerprint {
					recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"})
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil)
					return nil, ErrIdempotencyKeyConflict
				}
				existing = latest
			}
		}

		if !reclaimedByExpired {
			switch existing.Status {
			case IdempotencyStatusSucceeded:
				data, parseErr := c.decodeStoredResponse(existing.ResponseBody)
				if parseErr != nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "decode_stored_response_error")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->store_unavailable", false, map[string]string{
						"operation": "decode_stored_response",
					})
					return nil, ErrIdempotencyStoreUnavail.WithCause(parseErr)
				}
				recordIdempotencyReplay(opts.Route, opts.Scope, nil)
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->replayed", true, nil)
				return &IdempotencyExecuteResult{Data: data, Replayed: true}, nil
			case IdempotencyStatusProcessing:
				recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "in_progress"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->conflict", false, nil)
				return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now)
			case IdempotencyStatusFailedRetryable:
				if existing.LockedUntil != nil && existing.LockedUntil.After(now) {
					recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "retry_backoff"})
					recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil)
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->retry_backoff_conflict", false, nil)
					return nil, c.conflictWithRetryAfter(ErrIdempotencyRetryBackoff, existing.LockedUntil, now)
				}
				taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, IdempotencyStatusFailedRetryable, now, lockedUntil, expiresAt)
				if reclaimErr != nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_error")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->store_unavailable", false, map[string]string{
						"operation": "try_reclaim",
					})
					return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
				}
				if !taken {
					recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "reclaim_race"})
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->conflict", false, map[string]string{
						"conflict": "reclaim_race",
					})
					return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now)
				}
				recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "reclaim"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->processing", false, map[string]string{
					"claim_mode": "reclaim",
				})
				record.ID = existing.ID
			default:
				recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "unexpected_status"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->conflict", false, map[string]string{
					"status": existing.Status,
				})
				return nil, ErrIdempotencyKeyConflict
			}
		}
	}

	if record.ID == 0 {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "record_id_missing")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
			"operation": "record_id_missing",
		})
		return nil, ErrIdempotencyStoreUnavail
	}

	execStart := time.Now()
	defer func() {
		recordIdempotencyProcessingDuration(opts.Route, opts.Scope, time.Since(execStart), nil)
	}()

	data, execErr := execute(ctx)
	if execErr != nil {
		backoffUntil := time.Now().Add(c.cfg.FailedRetryBackoff)
		reason := infraerrors.Reason(execErr)
		if reason == "" {
			reason = "EXECUTION_FAILED"
		}
		recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil)
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->failed_retryable", false, map[string]string{
			"reason": reason,
		})
		if markErr := c.repo.MarkFailedRetryable(ctx, record.ID, reason, backoffUntil, expiresAt); markErr != nil {
			RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_failed_retryable_error")
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
				"operation": "mark_failed_retryable",
			})
		}
		return nil, execErr
	}

	storedBody, marshalErr := c.marshalStoredResponse(data)
	if marshalErr != nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "marshal_response_error")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
			"operation": "marshal_response",
		})
		return nil, ErrIdempotencyStoreUnavail.WithCause(marshalErr)
	}
	if markErr := c.repo.MarkSucceeded(ctx, record.ID, 200, storedBody, expiresAt); markErr != nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_succeeded_error")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
			"operation": "mark_succeeded",
		})
		return nil, ErrIdempotencyStoreUnavail.WithCause(markErr)
	}
	logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->succeeded", false, nil)

	return &IdempotencyExecuteResult{Data: data}, nil
}
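Execute wraps the business callback in the claim/replay/conflict protocol, so a handler only passes the header and payload through. A hedged call-site sketch — the handler, scope, route, and business call (orderSvc, uid, r, reqBody) are hypothetical, not part of this change:

result, err := coordinator.Execute(ctx, IdempotencyExecuteOptions{
	Scope:          "order.create",       // hypothetical scope name
	Method:         "POST",
	Route:          "/api/v1/orders",     // hypothetical route
	ActorScope:     "user:" + strconv.FormatInt(uid, 10),
	IdempotencyKey: r.Header.Get("Idempotency-Key"),
	Payload:        reqBody,
	RequireKey:     true,
}, func(ctx context.Context) (any, error) {
	return orderSvc.Create(ctx, reqBody) // runs at most once per key within the TTL
})
// On replay, result.Replayed is true and result.Data is the stored (redacted) response.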
func (c *IdempotencyCoordinator) conflictWithRetryAfter(base *infraerrors.ApplicationError, lockedUntil *time.Time, now time.Time) error {
	if lockedUntil == nil {
		return base
	}
	sec := int(lockedUntil.Sub(now).Seconds())
	if sec <= 0 {
		sec = 1
	}
	return base.WithMetadata(map[string]string{"retry_after": strconv.Itoa(sec)})
}

func (c *IdempotencyCoordinator) marshalStoredResponse(data any) (string, error) {
	raw, err := json.Marshal(data)
	if err != nil {
		return "", err
	}
	redacted := logredact.RedactText(string(raw))
	if c.cfg.MaxStoredResponseLen > 0 && len(redacted) > c.cfg.MaxStoredResponseLen {
		redacted = redacted[:c.cfg.MaxStoredResponseLen] + "...(truncated)"
	}
	return redacted, nil
}

func (c *IdempotencyCoordinator) decodeStoredResponse(stored *string) (any, error) {
	if stored == nil || strings.TrimSpace(*stored) == "" {
		return map[string]any{}, nil
	}
	var out any
	if err := json.Unmarshal([]byte(*stored), &out); err != nil {
		return nil, fmt.Errorf("decode stored response: %w", err)
	}
	return out, nil
}
backend/internal/service/idempotency_cleanup_service.go (new file, 91 lines)
@@ -0,0 +1,91 @@
package service

import (
	"context"
	"sync"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)

// IdempotencyCleanupService periodically removes expired idempotency records
// so the table does not grow without bound.
type IdempotencyCleanupService struct {
	repo     IdempotencyRepository
	interval time.Duration
	batch    int

	startOnce sync.Once
	stopOnce  sync.Once
	stopCh    chan struct{}
}

func NewIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService {
	interval := 60 * time.Second
	batch := 500
	if cfg != nil {
		if cfg.Idempotency.CleanupIntervalSeconds > 0 {
			interval = time.Duration(cfg.Idempotency.CleanupIntervalSeconds) * time.Second
		}
		if cfg.Idempotency.CleanupBatchSize > 0 {
			batch = cfg.Idempotency.CleanupBatchSize
		}
	}
	return &IdempotencyCleanupService{
		repo:     repo,
		interval: interval,
		batch:    batch,
		stopCh:   make(chan struct{}),
	}
}

func (s *IdempotencyCleanupService) Start() {
	if s == nil || s.repo == nil {
		return
	}
	s.startOnce.Do(func() {
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] started interval=%s batch=%d", s.interval, s.batch)
		go s.runLoop()
	})
}

func (s *IdempotencyCleanupService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		close(s.stopCh)
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] stopped")
	})
}

func (s *IdempotencyCleanupService) runLoop() {
	ticker := time.NewTicker(s.interval)
	defer ticker.Stop()

	// Run one cleanup pass right after start, so any backlog left over from a
	// restart is drained immediately.
	s.cleanupOnce()

	for {
		select {
		case <-ticker.C:
			s.cleanupOnce()
		case <-s.stopCh:
			return
		}
	}
}

func (s *IdempotencyCleanupService) cleanupOnce() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	deleted, err := s.repo.DeleteExpired(ctx, time.Now(), s.batch)
	if err != nil {
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleanup failed err=%v", err)
		return
	}
	if deleted > 0 {
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleaned expired records count=%d", deleted)
	}
}
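Start is idempotent via startOnce and Stop closes the channel exactly once, so the typical lifecycle is start-at-boot, stop-on-shutdown. A minimal sketch, where repo and cfg stand for the app's real dependencies:

cleanup := NewIdempotencyCleanupService(repo, cfg)
cleanup.Start()      // safe to call more than once; only the first call spawns the loop
defer cleanup.Stop() // later calls are no-ops thanks to stopOnce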
backend/internal/service/idempotency_cleanup_service_test.go (new file, 69 lines)
@@ -0,0 +1,69 @@
package service

import (
	"context"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/stretchr/testify/require"
)

type idempotencyCleanupRepoStub struct {
	deleteCalls int
	lastLimit   int
	deleteErr   error
}

func (r *idempotencyCleanupRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	return false, nil
}
func (r *idempotencyCleanupRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	return nil, nil
}
func (r *idempotencyCleanupRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, nil
}
func (r *idempotencyCleanupRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, nil
}
func (r *idempotencyCleanupRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return nil
}
func (r *idempotencyCleanupRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return nil
}
func (r *idempotencyCleanupRepoStub) DeleteExpired(_ context.Context, _ time.Time, limit int) (int64, error) {
	r.deleteCalls++
	r.lastLimit = limit
	if r.deleteErr != nil {
		return 0, r.deleteErr
	}
	return 1, nil
}

func TestNewIdempotencyCleanupService_UsesConfig(t *testing.T) {
	repo := &idempotencyCleanupRepoStub{}
	cfg := &config.Config{
		Idempotency: config.IdempotencyConfig{
			CleanupIntervalSeconds: 7,
			CleanupBatchSize:       321,
		},
	}
	svc := NewIdempotencyCleanupService(repo, cfg)
	require.Equal(t, 7*time.Second, svc.interval)
	require.Equal(t, 321, svc.batch)
}

func TestIdempotencyCleanupService_CleanupOnce(t *testing.T) {
	repo := &idempotencyCleanupRepoStub{}
	svc := NewIdempotencyCleanupService(repo, &config.Config{
		Idempotency: config.IdempotencyConfig{
			CleanupBatchSize: 99,
		},
	})

	svc.cleanupOnce()
	require.Equal(t, 1, repo.deleteCalls)
	require.Equal(t, 99, repo.lastLimit)
}
backend/internal/service/idempotency_observability.go (new file, 171 lines)
@@ -0,0 +1,171 @@
package service

import (
	"sort"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)

// IdempotencyMetricsSnapshot is a snapshot of the core idempotency metrics
// (accumulated in-process).
type IdempotencyMetricsSnapshot struct {
	ClaimTotal                uint64  `json:"claim_total"`
	ReplayTotal               uint64  `json:"replay_total"`
	ConflictTotal             uint64  `json:"conflict_total"`
	RetryBackoffTotal         uint64  `json:"retry_backoff_total"`
	ProcessingDurationCount   uint64  `json:"processing_duration_count"`
	ProcessingDurationTotalMs float64 `json:"processing_duration_total_ms"`
	StoreUnavailableTotal     uint64  `json:"store_unavailable_total"`
}

type idempotencyMetrics struct {
	claimTotal               atomic.Uint64
	replayTotal              atomic.Uint64
	conflictTotal            atomic.Uint64
	retryBackoffTotal        atomic.Uint64
	processingDurationCount  atomic.Uint64
	processingDurationMicros atomic.Uint64
	storeUnavailableTotal    atomic.Uint64
}

var defaultIdempotencyMetrics idempotencyMetrics

// GetIdempotencyMetricsSnapshot returns the current idempotency metrics snapshot.
func GetIdempotencyMetricsSnapshot() IdempotencyMetricsSnapshot {
	totalMicros := defaultIdempotencyMetrics.processingDurationMicros.Load()
	return IdempotencyMetricsSnapshot{
		ClaimTotal:                defaultIdempotencyMetrics.claimTotal.Load(),
		ReplayTotal:               defaultIdempotencyMetrics.replayTotal.Load(),
		ConflictTotal:             defaultIdempotencyMetrics.conflictTotal.Load(),
		RetryBackoffTotal:         defaultIdempotencyMetrics.retryBackoffTotal.Load(),
		ProcessingDurationCount:   defaultIdempotencyMetrics.processingDurationCount.Load(),
		ProcessingDurationTotalMs: float64(totalMicros) / 1000.0,
		StoreUnavailableTotal:     defaultIdempotencyMetrics.storeUnavailableTotal.Load(),
	}
}
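Durations are accumulated as integer microseconds and converted to milliseconds at read time, so mean latency falls straight out of the snapshot. A short illustrative read (assuming fmt is imported):

snap := GetIdempotencyMetricsSnapshot()
if snap.ProcessingDurationCount > 0 {
	avgMs := snap.ProcessingDurationTotalMs / float64(snap.ProcessingDurationCount)
	fmt.Printf("idempotent executions=%d avg=%.3fms\n", snap.ProcessingDurationCount, avgMs)
}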
func recordIdempotencyClaim(endpoint, scope string, attrs map[string]string) {
	defaultIdempotencyMetrics.claimTotal.Add(1)
	logIdempotencyMetric("idempotency_claim_total", endpoint, scope, "1", attrs)
}

func recordIdempotencyReplay(endpoint, scope string, attrs map[string]string) {
	defaultIdempotencyMetrics.replayTotal.Add(1)
	logIdempotencyMetric("idempotency_replay_total", endpoint, scope, "1", attrs)
}

func recordIdempotencyConflict(endpoint, scope string, attrs map[string]string) {
	defaultIdempotencyMetrics.conflictTotal.Add(1)
	logIdempotencyMetric("idempotency_conflict_total", endpoint, scope, "1", attrs)
}

func recordIdempotencyRetryBackoff(endpoint, scope string, attrs map[string]string) {
	defaultIdempotencyMetrics.retryBackoffTotal.Add(1)
	logIdempotencyMetric("idempotency_retry_backoff_total", endpoint, scope, "1", attrs)
}

func recordIdempotencyProcessingDuration(endpoint, scope string, duration time.Duration, attrs map[string]string) {
	if duration < 0 {
		duration = 0
	}
	defaultIdempotencyMetrics.processingDurationCount.Add(1)
	defaultIdempotencyMetrics.processingDurationMicros.Add(uint64(duration.Microseconds()))
	logIdempotencyMetric("idempotency_processing_duration_ms", endpoint, scope, strconv.FormatFloat(duration.Seconds()*1000, 'f', 3, 64), attrs)
}

// RecordIdempotencyStoreUnavailable records an idempotency-store-unavailable
// event (used to observe the degraded path).
func RecordIdempotencyStoreUnavailable(endpoint, scope, strategy string) {
	defaultIdempotencyMetrics.storeUnavailableTotal.Add(1)
	attrs := map[string]string{}
	if strategy != "" {
		attrs["strategy"] = strategy
	}
	logIdempotencyMetric("idempotency_store_unavailable_total", endpoint, scope, "1", attrs)
}

func logIdempotencyAudit(endpoint, scope, keyHash, stateTransition string, replayed bool, attrs map[string]string) {
	var b strings.Builder
	builderWriteString(&b, "[IdempotencyAudit]")
	builderWriteString(&b, " endpoint=")
	builderWriteString(&b, safeAuditField(endpoint))
	builderWriteString(&b, " scope=")
	builderWriteString(&b, safeAuditField(scope))
	builderWriteString(&b, " key_hash=")
	builderWriteString(&b, safeAuditField(keyHash))
	builderWriteString(&b, " state_transition=")
	builderWriteString(&b, safeAuditField(stateTransition))
	builderWriteString(&b, " replayed=")
	builderWriteString(&b, strconv.FormatBool(replayed))
	if len(attrs) > 0 {
		appendSortedAttrs(&b, attrs)
	}
	logger.LegacyPrintf("service.idempotency", "%s", b.String())
}

func logIdempotencyMetric(name, endpoint, scope, value string, attrs map[string]string) {
	var b strings.Builder
	builderWriteString(&b, "[IdempotencyMetric]")
	builderWriteString(&b, " name=")
	builderWriteString(&b, safeAuditField(name))
	builderWriteString(&b, " endpoint=")
	builderWriteString(&b, safeAuditField(endpoint))
	builderWriteString(&b, " scope=")
	builderWriteString(&b, safeAuditField(scope))
	builderWriteString(&b, " value=")
	builderWriteString(&b, safeAuditField(value))
	if len(attrs) > 0 {
		appendSortedAttrs(&b, attrs)
	}
	logger.LegacyPrintf("service.idempotency", "%s", b.String())
}

func appendSortedAttrs(builder *strings.Builder, attrs map[string]string) {
	if len(attrs) == 0 {
		return
	}
	keys := make([]string, 0, len(attrs))
	for k := range attrs {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		builderWriteByte(builder, ' ')
		builderWriteString(builder, k)
		builderWriteByte(builder, '=')
		builderWriteString(builder, safeAuditField(attrs[k]))
	}
}

func safeAuditField(v string) string {
	value := strings.TrimSpace(v)
	if value == "" {
		return "-"
	}
	// Logs are emitted as key=value pairs; replace whitespace to avoid
	// parsing ambiguity.
	value = strings.ReplaceAll(value, "\n", "_")
	value = strings.ReplaceAll(value, "\r", "_")
	value = strings.ReplaceAll(value, "\t", "_")
	value = strings.ReplaceAll(value, " ", "_")
	return value
}

func resetIdempotencyMetricsForTest() {
	defaultIdempotencyMetrics.claimTotal.Store(0)
	defaultIdempotencyMetrics.replayTotal.Store(0)
	defaultIdempotencyMetrics.conflictTotal.Store(0)
	defaultIdempotencyMetrics.retryBackoffTotal.Store(0)
	defaultIdempotencyMetrics.processingDurationCount.Store(0)
	defaultIdempotencyMetrics.processingDurationMicros.Store(0)
	defaultIdempotencyMetrics.storeUnavailableTotal.Store(0)
}

func builderWriteString(builder *strings.Builder, value string) {
	_, _ = builder.WriteString(value)
}

func builderWriteByte(builder *strings.Builder, value byte) {
	_ = builder.WriteByte(value)
}
backend/internal/service/idempotency_test.go (new file, 805 lines)
@@ -0,0 +1,805 @@
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type inMemoryIdempotencyRepo struct {
|
||||
mu sync.Mutex
|
||||
nextID int64
|
||||
data map[string]*IdempotencyRecord
|
||||
}
|
||||
|
||||
func newInMemoryIdempotencyRepo() *inMemoryIdempotencyRepo {
|
||||
return &inMemoryIdempotencyRepo{
|
||||
nextID: 1,
|
||||
data: make(map[string]*IdempotencyRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) key(scope, hash string) string {
|
||||
return scope + "|" + hash
|
||||
}
|
||||
|
||||
func cloneRecord(in *IdempotencyRecord) *IdempotencyRecord {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := *in
|
||||
if in.ResponseStatus != nil {
|
||||
v := *in.ResponseStatus
|
||||
out.ResponseStatus = &v
|
||||
}
|
||||
if in.ResponseBody != nil {
|
||||
v := *in.ResponseBody
|
||||
out.ResponseBody = &v
|
||||
}
|
||||
if in.ErrorReason != nil {
|
||||
v := *in.ErrorReason
|
||||
out.ErrorReason = &v
|
||||
}
|
||||
if in.LockedUntil != nil {
|
||||
v := *in.LockedUntil
|
||||
out.LockedUntil = &v
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) CreateProcessing(_ context.Context, record *IdempotencyRecord) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
k := r.key(record.Scope, record.IdempotencyKeyHash)
|
||||
if _, ok := r.data[k]; ok {
|
||||
return false, nil
|
||||
}
|
||||
rec := cloneRecord(record)
|
||||
rec.ID = r.nextID
|
||||
rec.CreatedAt = time.Now()
|
||||
rec.UpdatedAt = rec.CreatedAt
|
||||
r.nextID++
|
||||
r.data[k] = rec
|
||||
record.ID = rec.ID
|
||||
record.CreatedAt = rec.CreatedAt
|
||||
record.UpdatedAt = rec.UpdatedAt
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*IdempotencyRecord, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return cloneRecord(r.data[r.key(scope, keyHash)]), nil
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != fromStatus {
|
||||
return false, nil
|
||||
}
|
||||
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
|
||||
return false, nil
|
||||
}
|
||||
rec.Status = IdempotencyStatusProcessing
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.ErrorReason = nil
|
||||
rec.UpdatedAt = time.Now()
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
|
||||
return false, nil
|
||||
}
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.UpdatedAt = time.Now()
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = IdempotencyStatusSucceeded
|
||||
rec.LockedUntil = nil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.UpdatedAt = time.Now()
|
||||
rec.ErrorReason = nil
|
||||
rec.ResponseStatus = &responseStatus
|
||||
rec.ResponseBody = &responseBody
|
||||
return nil
|
||||
}
|
||||
return errors.New("record not found")
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = IdempotencyStatusFailedRetryable
|
||||
rec.LockedUntil = &lockedUntil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.UpdatedAt = time.Now()
|
||||
rec.ErrorReason = &errorReason
|
||||
return nil
|
||||
}
|
||||
return errors.New("record not found")
|
||||
}
|
||||
|
||||
func (r *inMemoryIdempotencyRepo) DeleteExpired(_ context.Context, now time.Time, _ int) (int64, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
var deleted int64
|
||||
for k, rec := range r.data {
|
||||
if !rec.ExpiresAt.After(now) {
|
||||
delete(r.data, k)
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_RequireKey(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
cfg.ObserveOnly = false
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "admin:1",
|
||||
RequireKey: true,
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyRequired))
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_ReplaySucceededResult(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
execCount := 0
|
||||
exec := func(ctx context.Context) (any, error) {
|
||||
execCount++
|
||||
return map[string]any{"count": execCount}, nil
|
||||
}
|
||||
|
||||
opts := IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "user:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-1",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}
|
||||
|
||||
first, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.False(t, first.Replayed)
|
||||
|
||||
second, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.True(t, second.Replayed)
|
||||
require.Equal(t, 1, execCount, "second request should replay without executing business logic")
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.Equal(t, uint64(1), metrics.ClaimTotal)
|
||||
require.Equal(t, uint64(1), metrics.ReplayTotal)
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
|
||||
|
||||
opts := IdempotencyExecuteOptions{
|
||||
Scope: "test.scope.expired",
|
||||
Method: "POST",
|
||||
Route: "/test/expired",
|
||||
ActorScope: "user:99",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "expired-case",
|
||||
Payload: map[string]any{"k": "v"},
|
||||
}
|
||||
|
||||
execCount := 0
|
||||
exec := func(ctx context.Context) (any, error) {
|
||||
execCount++
|
||||
return map[string]any{"count": execCount}, nil
|
||||
}
|
||||
|
||||
first, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, first)
|
||||
require.False(t, first.Replayed)
|
||||
require.Equal(t, 1, execCount)
|
||||
|
||||
keyHash := HashIdempotencyKey(opts.IdempotencyKey)
|
||||
repo.mu.Lock()
|
||||
existing := repo.data[repo.key(opts.Scope, keyHash)]
|
||||
require.NotNil(t, existing)
|
||||
existing.ExpiresAt = time.Now().Add(-time.Second)
|
||||
repo.mu.Unlock()
|
||||
|
||||
second, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, second)
|
||||
require.False(t, second.Replayed, "expired record should be reclaimed and execute business logic again")
|
||||
require.Equal(t, 2, execCount)
|
||||
|
||||
third, err := coordinator.Execute(context.Background(), opts, exec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, third)
|
||||
require.True(t, third.Replayed)
|
||||
payload, ok := third.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, float64(2), payload["count"])
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.GreaterOrEqual(t, metrics.ClaimTotal, uint64(2))
|
||||
require.GreaterOrEqual(t, metrics.ReplayTotal, uint64(1))
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "user:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-2",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "user:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-2",
|
||||
Payload: map[string]any{"a": 2},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyConflict))
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.Equal(t, uint64(1), metrics.ConflictTotal)
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_BackoffAfterRetryableFailure(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
cfg.FailedRetryBackoff = 2 * time.Second
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
opts := IdempotencyExecuteOptions{
|
||||
Scope: "test.scope",
|
||||
Method: "POST",
|
||||
Route: "/test",
|
||||
ActorScope: "user:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-3",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
|
||||
return nil, infraerrors.InternalServer("UPSTREAM_ERROR", "upstream error")
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyRetryBackoff))
|
||||
require.Greater(t, RetryAfterSecondsFromError(err), 0)
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.GreaterOrEqual(t, metrics.RetryBackoffTotal, uint64(2))
|
||||
require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1))
|
||||
require.GreaterOrEqual(t, metrics.ProcessingDurationCount, uint64(1))
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
cfg := DefaultIdempotencyConfig()
|
||||
cfg.ProcessingTimeout = 2 * time.Second
|
||||
coordinator := NewIdempotencyCoordinator(repo, cfg)
|
||||
|
||||
opts := IdempotencyExecuteOptions{
|
||||
Scope: "test.scope.concurrent",
|
||||
Method: "POST",
|
||||
Route: "/test/concurrent",
|
||||
ActorScope: "user:7",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "concurrent-case",
|
||||
Payload: map[string]any{"v": 1},
|
||||
}
|
||||
|
||||
var execCount int32
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
|
||||
atomic.AddInt32(&execCount, 1)
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
replayed, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
|
||||
atomic.AddInt32(&execCount, 1)
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, replayed.Replayed)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&execCount), "concurrent same-key requests should execute business side-effect once")
|
||||
|
||||
metrics := GetIdempotencyMetricsSnapshot()
|
||||
require.Equal(t, uint64(1), metrics.ClaimTotal)
|
||||
require.Equal(t, uint64(1), metrics.ReplayTotal)
|
||||
require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1))
|
||||
}
|
||||
|
||||
type failingIdempotencyRepo struct{}
|
||||
|
||||
func (failingIdempotencyRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
|
||||
return nil, errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (failingIdempotencyRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, errors.New("store unavailable")
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_StoreUnavailableMetrics(t *testing.T) {
|
||||
resetIdempotencyMetricsForTest()
|
||||
coordinator := NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig())
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "test.scope.unavailable",
|
||||
Method: "POST",
|
||||
Route: "/test/unavailable",
|
||||
ActorScope: "admin:1",
|
||||
RequireKey: true,
|
||||
IdempotencyKey: "case-unavailable",
|
||||
Payload: map[string]any{"v": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
|
||||
require.GreaterOrEqual(t, GetIdempotencyMetricsSnapshot().StoreUnavailableTotal, uint64(1))
|
||||
}
|
||||
|
||||
func TestDefaultIdempotencyCoordinatorAndTTLs(t *testing.T) {
|
||||
SetDefaultIdempotencyCoordinator(nil)
|
||||
require.Nil(t, DefaultIdempotencyCoordinator())
|
||||
require.Equal(t, DefaultIdempotencyConfig().DefaultTTL, DefaultWriteIdempotencyTTL())
|
||||
require.Equal(t, DefaultIdempotencyConfig().SystemOperationTTL, DefaultSystemOperationIdempotencyTTL())
|
||||
|
||||
coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{
|
||||
DefaultTTL: 2 * time.Hour,
|
||||
SystemOperationTTL: 15 * time.Minute,
|
||||
ProcessingTimeout: 10 * time.Second,
|
||||
FailedRetryBackoff: 3 * time.Second,
|
||||
ObserveOnly: false,
|
||||
})
|
||||
SetDefaultIdempotencyCoordinator(coordinator)
|
||||
t.Cleanup(func() {
|
||||
SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
require.Same(t, coordinator, DefaultIdempotencyCoordinator())
|
||||
require.Equal(t, 2*time.Hour, DefaultWriteIdempotencyTTL())
|
||||
require.Equal(t, 15*time.Minute, DefaultSystemOperationIdempotencyTTL())
|
||||
}
|
||||
|
||||
func TestNormalizeIdempotencyKeyAndFingerprint(t *testing.T) {
|
||||
key, err := NormalizeIdempotencyKey(" abc-123 ")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "abc-123", key)
|
||||
|
||||
key, err = NormalizeIdempotencyKey("")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", key)
|
||||
|
||||
_, err = NormalizeIdempotencyKey(string(make([]byte, 129)))
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = NormalizeIdempotencyKey("bad\nkey")
|
||||
require.Error(t, err)
|
||||
|
||||
fp1, err := BuildIdempotencyFingerprint("", "", "", map[string]any{"a": 1})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, fp1)
|
||||
fp2, err := BuildIdempotencyFingerprint("POST", "/", "anonymous", map[string]any{"a": 1})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fp1, fp2)
|
||||
|
||||
_, err = BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"bad": make(chan int)})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, infraerrors.Code(ErrIdempotencyInvalidPayload), infraerrors.Code(err))
|
||||
}
|
||||
|
||||
func TestRetryAfterSecondsFromErrorBranches(t *testing.T) {
|
||||
require.Equal(t, 0, RetryAfterSecondsFromError(nil))
|
||||
require.Equal(t, 0, RetryAfterSecondsFromError(errors.New("plain")))
|
||||
|
||||
err := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "12"})
|
||||
require.Equal(t, 12, RetryAfterSecondsFromError(err))
|
||||
|
||||
err = ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "bad"})
|
||||
require.Equal(t, 0, RetryAfterSecondsFromError(err))
|
||||
}
|
||||
|
||||
func TestIdempotencyCoordinator_ExecuteNilExecutorAndNoKeyPassThrough(t *testing.T) {
|
||||
repo := newInMemoryIdempotencyRepo()
|
||||
coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
|
||||
|
||||
_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope",
|
||||
IdempotencyKey: "k",
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, nil)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "IDEMPOTENCY_EXECUTOR_NIL", infraerrors.Reason(err))
|
||||
|
||||
called := 0
|
||||
result, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
|
||||
Scope: "scope",
|
||||
RequireKey: true,
|
||||
Payload: map[string]any{"a": 1},
|
||||
}, func(ctx context.Context) (any, error) {
|
||||
called++
|
||||
return map[string]any{"ok": true}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, called)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.Replayed)
|
||||
}
|
||||
|
||||
type noIDOwnerRepo struct{}
|
||||
|
||||
func (noIDOwnerRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (noIDOwnerRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (noIDOwnerRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (noIDOwnerRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (noIDOwnerRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { return nil }
|
||||
func (noIDOwnerRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (noIDOwnerRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { return 0, nil }
|
||||
|
||||

func TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing(t *testing.T) {
    cfg := DefaultIdempotencyConfig()
    coordinator := NewIdempotencyCoordinator(nil, cfg)

    _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        Scope:          "scope",
        IdempotencyKey: "k",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return map[string]any{"ok": true}, nil
    })
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))

    coordinator = NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg)
    _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        IdempotencyKey: "k2",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return map[string]any{"ok": true}, nil
    })
    require.Error(t, err)
    require.Equal(t, "IDEMPOTENCY_SCOPE_REQUIRED", infraerrors.Reason(err))

    coordinator = NewIdempotencyCoordinator(noIDOwnerRepo{}, cfg)
    _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        Scope:          "scope-no-id",
        IdempotencyKey: "k3",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return map[string]any{"ok": true}, nil
    })
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
}

type conflictBranchRepo struct {
    existing      *IdempotencyRecord
    tryReclaimErr error
    tryReclaimOK  bool
}

func (r *conflictBranchRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
    return false, nil
}
func (r *conflictBranchRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
    return cloneRecord(r.existing), nil
}
func (r *conflictBranchRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
    if r.tryReclaimErr != nil {
        return false, r.tryReclaimErr
    }
    return r.tryReclaimOK, nil
}
func (r *conflictBranchRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
    return false, nil
}
func (r *conflictBranchRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
    return nil
}
func (r *conflictBranchRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
    return nil
}
func (r *conflictBranchRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) {
    return 0, nil
}

func TestIdempotencyCoordinator_ConflictBranchesAndDecodeError(t *testing.T) {
    now := time.Now()
    fp, err := BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"a": 1})
    require.NoError(t, err)
    badBody := "{bad-json"
    repo := &conflictBranchRepo{
        existing: &IdempotencyRecord{
            ID:                 1,
            Scope:              "scope",
            IdempotencyKeyHash: HashIdempotencyKey("k"),
            RequestFingerprint: fp,
            Status:             IdempotencyStatusSucceeded,
            ResponseBody:       &badBody,
            ExpiresAt:          now.Add(time.Hour),
        },
    }
    coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
    _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        Scope:          "scope",
        IdempotencyKey: "k",
        Method:         "POST",
        Route:          "/x",
        ActorScope:     "u:1",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return map[string]any{"ok": true}, nil
    })
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))

    repo.existing = &IdempotencyRecord{
        ID:                 2,
        Scope:              "scope",
        IdempotencyKeyHash: HashIdempotencyKey("k"),
        RequestFingerprint: fp,
        Status:             "unknown",
        ExpiresAt:          now.Add(time.Hour),
    }
    _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        Scope:          "scope",
        IdempotencyKey: "k",
        Method:         "POST",
        Route:          "/x",
        ActorScope:     "u:1",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return map[string]any{"ok": true}, nil
    })
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyKeyConflict), infraerrors.Code(err))

    repo.existing = &IdempotencyRecord{
        ID:                 3,
        Scope:              "scope",
        IdempotencyKeyHash: HashIdempotencyKey("k"),
        RequestFingerprint: fp,
        Status:             IdempotencyStatusFailedRetryable,
        LockedUntil:        ptrTime(now.Add(-time.Second)),
        ExpiresAt:          now.Add(time.Hour),
    }
    repo.tryReclaimErr = errors.New("reclaim down")
    _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        Scope:          "scope",
        IdempotencyKey: "k",
        Method:         "POST",
        Route:          "/x",
        ActorScope:     "u:1",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return map[string]any{"ok": true}, nil
    })
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))

    repo.tryReclaimErr = nil
    repo.tryReclaimOK = false
    _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        Scope:          "scope",
        IdempotencyKey: "k",
        Method:         "POST",
        Route:          "/x",
        ActorScope:     "u:1",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return map[string]any{"ok": true}, nil
    })
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyInProgress), infraerrors.Code(err))
}

type markBehaviorRepo struct {
    inMemoryIdempotencyRepo
    failMarkSucceeded bool
    failMarkFailed    bool
}

func (r *markBehaviorRepo) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
    if r.failMarkSucceeded {
        return errors.New("mark succeeded failed")
    }
    return r.inMemoryIdempotencyRepo.MarkSucceeded(ctx, id, responseStatus, responseBody, expiresAt)
}

func (r *markBehaviorRepo) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
    if r.failMarkFailed {
        return errors.New("mark failed retryable failed")
    }
    return r.inMemoryIdempotencyRepo.MarkFailedRetryable(ctx, id, errorReason, lockedUntil, expiresAt)
}

func TestIdempotencyCoordinator_MarkAndMarshalBranches(t *testing.T) {
    repo := &markBehaviorRepo{inMemoryIdempotencyRepo: *newInMemoryIdempotencyRepo()}
    coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())

    repo.failMarkSucceeded = true
    _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        Scope:          "scope-success",
        IdempotencyKey: "k1",
        Method:         "POST",
        Route:          "/ok",
        ActorScope:     "u:1",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return map[string]any{"ok": true}, nil
    })
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))

    repo.failMarkSucceeded = false
    _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        Scope:          "scope-marshal",
        IdempotencyKey: "k2",
        Method:         "POST",
        Route:          "/bad",
        ActorScope:     "u:1",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return map[string]any{"bad": make(chan int)}, nil
    })
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))

    repo.failMarkFailed = true
    _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
        Scope:          "scope-fail",
        IdempotencyKey: "k3",
        Method:         "POST",
        Route:          "/fail",
        ActorScope:     "u:1",
        Payload:        map[string]any{"a": 1},
    }, func(ctx context.Context) (any, error) {
        return nil, errors.New("plain failure")
    })
    require.Error(t, err)
    require.Equal(t, "plain failure", err.Error())
}

func TestIdempotencyCoordinator_HelperBranches(t *testing.T) {
    c := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{
        DefaultTTL:           time.Hour,
        SystemOperationTTL:   time.Hour,
        ProcessingTimeout:    time.Second,
        FailedRetryBackoff:   time.Second,
        MaxStoredResponseLen: 12,
        ObserveOnly:          false,
    })

    // conflictWithRetryAfter without locked_until should return base error.
    base := ErrIdempotencyInProgress
    err := c.conflictWithRetryAfter(base, nil, time.Now())
    require.Equal(t, infraerrors.Code(base), infraerrors.Code(err))

    // marshalStoredResponse should truncate.
    body, err := c.marshalStoredResponse(map[string]any{"long": "abcdefghijklmnopqrstuvwxyz"})
    require.NoError(t, err)
    require.Contains(t, body, "...(truncated)")

    // decodeStoredResponse empty and invalid json.
    out, err := c.decodeStoredResponse(nil)
    require.NoError(t, err)
    _, ok := out.(map[string]any)
    require.True(t, ok)

    invalid := "{invalid"
    _, err = c.decodeStoredResponse(&invalid)
    require.Error(t, err)
}
389
backend/internal/service/subscription_assign_idempotency_test.go
Normal file
@@ -0,0 +1,389 @@
package service

import (
    "context"
    "strconv"
    "testing"
    "time"

    infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
    "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
    "github.com/stretchr/testify/require"
)

type groupRepoNoop struct{}

func (groupRepoNoop) Create(context.Context, *Group) error { panic("unexpected Create call") }
func (groupRepoNoop) GetByID(context.Context, int64) (*Group, error) {
    panic("unexpected GetByID call")
}
func (groupRepoNoop) GetByIDLite(context.Context, int64) (*Group, error) {
    panic("unexpected GetByIDLite call")
}
func (groupRepoNoop) Update(context.Context, *Group) error { panic("unexpected Update call") }
func (groupRepoNoop) Delete(context.Context, int64) error { panic("unexpected Delete call") }
func (groupRepoNoop) DeleteCascade(context.Context, int64) ([]int64, error) {
    panic("unexpected DeleteCascade call")
}
func (groupRepoNoop) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
    panic("unexpected List call")
}
func (groupRepoNoop) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
    panic("unexpected ListWithFilters call")
}
func (groupRepoNoop) ListActive(context.Context) ([]Group, error) {
    panic("unexpected ListActive call")
}
func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, error) {
    panic("unexpected ListActiveByPlatform call")
}
func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) {
    panic("unexpected ExistsByName call")
}
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) {
    panic("unexpected GetAccountCount call")
}
func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
    panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (groupRepoNoop) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
    panic("unexpected GetAccountIDsByGroupIDs call")
}
func (groupRepoNoop) BindAccountsToGroup(context.Context, int64, []int64) error {
    panic("unexpected BindAccountsToGroup call")
}
func (groupRepoNoop) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
    panic("unexpected UpdateSortOrders call")
}

type subscriptionGroupRepoStub struct {
    groupRepoNoop
    group *Group
}

func (s *subscriptionGroupRepoStub) GetByID(context.Context, int64) (*Group, error) {
    return s.group, nil
}

type userSubRepoNoop struct{}

func (userSubRepoNoop) Create(context.Context, *UserSubscription) error {
    panic("unexpected Create call")
}
func (userSubRepoNoop) GetByID(context.Context, int64) (*UserSubscription, error) {
    panic("unexpected GetByID call")
}
func (userSubRepoNoop) GetByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
    panic("unexpected GetByUserIDAndGroupID call")
}
func (userSubRepoNoop) GetActiveByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
    panic("unexpected GetActiveByUserIDAndGroupID call")
}
func (userSubRepoNoop) Update(context.Context, *UserSubscription) error {
    panic("unexpected Update call")
}
func (userSubRepoNoop) Delete(context.Context, int64) error { panic("unexpected Delete call") }
func (userSubRepoNoop) ListByUserID(context.Context, int64) ([]UserSubscription, error) {
    panic("unexpected ListByUserID call")
}
func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscription, error) {
    panic("unexpected ListActiveByUserID call")
}
func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
    panic("unexpected ListByGroupID call")
}
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
    panic("unexpected List call")
}
func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {
    panic("unexpected ExistsByUserIDAndGroupID call")
}
func (userSubRepoNoop) ExtendExpiry(context.Context, int64, time.Time) error {
    panic("unexpected ExtendExpiry call")
}
func (userSubRepoNoop) UpdateStatus(context.Context, int64, string) error {
    panic("unexpected UpdateStatus call")
}
func (userSubRepoNoop) UpdateNotes(context.Context, int64, string) error {
    panic("unexpected UpdateNotes call")
}
func (userSubRepoNoop) ActivateWindows(context.Context, int64, time.Time) error {
    panic("unexpected ActivateWindows call")
}
func (userSubRepoNoop) ResetDailyUsage(context.Context, int64, time.Time) error {
    panic("unexpected ResetDailyUsage call")
}
func (userSubRepoNoop) ResetWeeklyUsage(context.Context, int64, time.Time) error {
    panic("unexpected ResetWeeklyUsage call")
}
func (userSubRepoNoop) ResetMonthlyUsage(context.Context, int64, time.Time) error {
    panic("unexpected ResetMonthlyUsage call")
}
func (userSubRepoNoop) IncrementUsage(context.Context, int64, float64) error {
    panic("unexpected IncrementUsage call")
}
func (userSubRepoNoop) BatchUpdateExpiredStatus(context.Context) (int64, error) {
    panic("unexpected BatchUpdateExpiredStatus call")
}

type subscriptionUserSubRepoStub struct {
    userSubRepoNoop

    nextID      int64
    byID        map[int64]*UserSubscription
    byUserGroup map[string]*UserSubscription
    createCalls int
}

func newSubscriptionUserSubRepoStub() *subscriptionUserSubRepoStub {
    return &subscriptionUserSubRepoStub{
        nextID:      1,
        byID:        make(map[int64]*UserSubscription),
        byUserGroup: make(map[string]*UserSubscription),
    }
}

func (s *subscriptionUserSubRepoStub) key(userID, groupID int64) string {
    return strconvFormatInt(userID) + ":" + strconvFormatInt(groupID)
}

func (s *subscriptionUserSubRepoStub) seed(sub *UserSubscription) {
    if sub == nil {
        return
    }
    cp := *sub
    if cp.ID == 0 {
        cp.ID = s.nextID
        s.nextID++
    }
    s.byID[cp.ID] = &cp
    s.byUserGroup[s.key(cp.UserID, cp.GroupID)] = &cp
}

func (s *subscriptionUserSubRepoStub) ExistsByUserIDAndGroupID(_ context.Context, userID, groupID int64) (bool, error) {
    _, ok := s.byUserGroup[s.key(userID, groupID)]
    return ok, nil
}

func (s *subscriptionUserSubRepoStub) GetByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
    sub := s.byUserGroup[s.key(userID, groupID)]
    if sub == nil {
        return nil, ErrSubscriptionNotFound
    }
    cp := *sub
    return &cp, nil
}

func (s *subscriptionUserSubRepoStub) Create(_ context.Context, sub *UserSubscription) error {
    if sub == nil {
        return nil
    }
    s.createCalls++
    cp := *sub
    if cp.ID == 0 {
        cp.ID = s.nextID
        s.nextID++
    }
    sub.ID = cp.ID
    s.byID[cp.ID] = &cp
    s.byUserGroup[s.key(cp.UserID, cp.GroupID)] = &cp
    return nil
}

func (s *subscriptionUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) {
    sub := s.byID[id]
    if sub == nil {
        return nil, ErrSubscriptionNotFound
    }
    cp := *sub
    return &cp, nil
}

func TestAssignSubscriptionReuseWhenSemanticsMatch(t *testing.T) {
    start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
    groupRepo := &subscriptionGroupRepoStub{
        group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
    }
    subRepo := newSubscriptionUserSubRepoStub()
    subRepo.seed(&UserSubscription{
        ID:        10,
        UserID:    1001,
        GroupID:   1,
        StartsAt:  start,
        ExpiresAt: start.AddDate(0, 0, 30),
        Notes:     "init",
    })

    svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
    sub, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
        UserID:       1001,
        GroupID:      1,
        ValidityDays: 30,
        Notes:        "init",
    })
    require.NoError(t, err)
    require.Equal(t, int64(10), sub.ID)
    require.Equal(t, 0, subRepo.createCalls, "reuse should not create new subscription")
}

func TestAssignSubscriptionConflictWhenSemanticsMismatch(t *testing.T) {
    start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
    groupRepo := &subscriptionGroupRepoStub{
        group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
    }
    subRepo := newSubscriptionUserSubRepoStub()
    subRepo.seed(&UserSubscription{
        ID:        11,
        UserID:    2001,
        GroupID:   1,
        StartsAt:  start,
        ExpiresAt: start.AddDate(0, 0, 30),
        Notes:     "old-note",
    })

    svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
    _, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
        UserID:       2001,
        GroupID:      1,
        ValidityDays: 30,
        Notes:        "new-note",
    })
    require.Error(t, err)
    require.Equal(t, "SUBSCRIPTION_ASSIGN_CONFLICT", infraerrorsReason(err))
    require.Equal(t, 0, subRepo.createCalls, "conflict should not create or mutate existing subscription")
}

func TestBulkAssignSubscriptionCreatedReusedAndConflict(t *testing.T) {
    start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
    groupRepo := &subscriptionGroupRepoStub{
        group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
    }
    subRepo := newSubscriptionUserSubRepoStub()
    // user 1: semantics match, should be counted as reused
    subRepo.seed(&UserSubscription{
        ID:        21,
        UserID:    1,
        GroupID:   1,
        StartsAt:  start,
        ExpiresAt: start.AddDate(0, 0, 30),
        Notes:     "same-note",
    })
    // user 3: semantic conflict (validity period differs), should be counted as failed
    subRepo.seed(&UserSubscription{
        ID:        23,
        UserID:    3,
        GroupID:   1,
        StartsAt:  start,
        ExpiresAt: start.AddDate(0, 0, 60),
        Notes:     "same-note",
    })

    svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
    result, err := svc.BulkAssignSubscription(context.Background(), &BulkAssignSubscriptionInput{
        UserIDs:      []int64{1, 2, 3},
        GroupID:      1,
        ValidityDays: 30,
        AssignedBy:   9,
        Notes:        "same-note",
    })
    require.NoError(t, err)
    require.Equal(t, 2, result.SuccessCount)
    require.Equal(t, 1, result.CreatedCount)
    require.Equal(t, 1, result.ReusedCount)
    require.Equal(t, 1, result.FailedCount)
    require.Equal(t, "reused", result.Statuses[1])
    require.Equal(t, "created", result.Statuses[2])
    require.Equal(t, "failed", result.Statuses[3])
    require.Equal(t, 1, subRepo.createCalls)
}
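
A small consumer sketch (the function name is illustrative; assumes fmt, sort, and strings imports) showing how the per-user Statuses map from BulkAssignResult can be turned into an operator-facing summary:

func summarizeBulkAssign(result *BulkAssignResult) string {
    // Statuses maps each user ID to "created", "reused", or "failed".
    parts := make([]string, 0, len(result.Statuses))
    for userID, status := range result.Statuses {
        parts = append(parts, fmt.Sprintf("user %d: %s", userID, status))
    }
    sort.Strings(parts)
    return fmt.Sprintf("success=%d created=%d reused=%d failed=%d (%s)",
        result.SuccessCount, result.CreatedCount, result.ReusedCount,
        result.FailedCount, strings.Join(parts, ", "))
}
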
func TestAssignSubscriptionKeepsWorkingWhenIdempotencyStoreUnavailable(t *testing.T) {
    groupRepo := &subscriptionGroupRepoStub{
        group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
    }
    subRepo := newSubscriptionUserSubRepoStub()
    SetDefaultIdempotencyCoordinator(NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig()))
    t.Cleanup(func() {
        SetDefaultIdempotencyCoordinator(nil)
    })

    svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
    sub, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
        UserID:       9001,
        GroupID:      1,
        ValidityDays: 30,
        Notes:        "new",
    })
    require.NoError(t, err)
    require.NotNil(t, sub)
    require.Equal(t, 1, subRepo.createCalls, "semantic idempotent endpoint should not depend on idempotency store availability")
}

func TestNormalizeAssignValidityDays(t *testing.T) {
    require.Equal(t, 30, normalizeAssignValidityDays(0))
    require.Equal(t, 30, normalizeAssignValidityDays(-5))
    require.Equal(t, MaxValidityDays, normalizeAssignValidityDays(MaxValidityDays+100))
    require.Equal(t, 7, normalizeAssignValidityDays(7))
}

func TestDetectAssignSemanticConflictCases(t *testing.T) {
    start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
    base := &UserSubscription{
        UserID:    1,
        GroupID:   1,
        StartsAt:  start,
        ExpiresAt: start.AddDate(0, 0, 30),
        Notes:     "same",
    }

    reason, conflict := detectAssignSemanticConflict(base, &AssignSubscriptionInput{
        UserID:       1,
        GroupID:      1,
        ValidityDays: 30,
        Notes:        "same",
    })
    require.False(t, conflict)
    require.Equal(t, "", reason)

    reason, conflict = detectAssignSemanticConflict(base, &AssignSubscriptionInput{
        UserID:       1,
        GroupID:      1,
        ValidityDays: 60,
        Notes:        "same",
    })
    require.True(t, conflict)
    require.Equal(t, "validity_days_mismatch", reason)

    reason, conflict = detectAssignSemanticConflict(base, &AssignSubscriptionInput{
        UserID:       1,
        GroupID:      1,
        ValidityDays: 30,
        Notes:        "other",
    })
    require.True(t, conflict)
    require.Equal(t, "notes_mismatch", reason)
}

func TestAssignSubscriptionGroupTypeValidation(t *testing.T) {
    groupRepo := &subscriptionGroupRepoStub{
        group: &Group{ID: 1, SubscriptionType: SubscriptionTypeStandard},
    }
    subRepo := newSubscriptionUserSubRepoStub()
    svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)

    _, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
        UserID:       1,
        GroupID:      1,
        ValidityDays: 30,
    })
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrGroupNotSubscriptionType), infraerrors.Code(err))
}

func strconvFormatInt(v int64) string {
    return strconv.FormatInt(v, 10)
}

func infraerrorsReason(err error) string {
    return infraerrors.Reason(err)
}

@@ -6,6 +6,7 @@ import (
    "log"
    "math/rand/v2"
    "strconv"
    "strings"
    "time"

    dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -24,16 +25,17 @@ var MaxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
const MaxValidityDays = 36500

var (
    ErrSubscriptionNotFound      = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
    ErrSubscriptionExpired       = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
    ErrSubscriptionSuspended     = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
    ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
    ErrGroupNotSubscriptionType  = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
    ErrDailyLimitExceeded        = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
    ErrWeeklyLimitExceeded       = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
    ErrMonthlyLimitExceeded      = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
    ErrSubscriptionNilInput      = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
    ErrAdjustWouldExpire         = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
    ErrSubscriptionNotFound       = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
    ErrSubscriptionExpired        = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
    ErrSubscriptionSuspended      = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
    ErrSubscriptionAlreadyExists  = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
    ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
    ErrGroupNotSubscriptionType   = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
    ErrDailyLimitExceeded         = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
    ErrWeeklyLimitExceeded        = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
    ErrMonthlyLimitExceeded       = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
    ErrSubscriptionNilInput       = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
    ErrAdjustWouldExpire          = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
)

// SubscriptionService implements the subscription service.
@@ -150,40 +152,10 @@ type AssignSubscriptionInput struct {

// AssignSubscription assigns a subscription to a user (duplicate assignment is not allowed)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
    // Check that the group exists and is a subscription-type group
    group, err := s.groupRepo.GetByID(ctx, input.GroupID)
    if err != nil {
        return nil, fmt.Errorf("group not found: %w", err)
    }
    if !group.IsSubscriptionType() {
        return nil, ErrGroupNotSubscriptionType
    }

    // Check whether a subscription already exists
    exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
    sub, _, err := s.assignSubscriptionWithReuse(ctx, input)
    if err != nil {
        return nil, err
    }
    if exists {
        return nil, ErrSubscriptionAlreadyExists
    }

    sub, err := s.createSubscription(ctx, input)
    if err != nil {
        return nil, err
    }

    // Invalidate subscription caches
    s.InvalidateSubCache(input.UserID, input.GroupID)
    if s.billingCacheService != nil {
        userID, groupID := input.UserID, input.GroupID
        go func() {
            cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
            defer cancel()
            _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
        }()
    }

    return sub, nil
}

@@ -363,9 +335,12 @@ type BulkAssignSubscriptionInput struct {
// BulkAssignResult is the result of a bulk assignment
type BulkAssignResult struct {
    SuccessCount  int
    CreatedCount  int
    ReusedCount   int
    FailedCount   int
    Subscriptions []UserSubscription
    Errors        []string
    Statuses      map[int64]string
}

// BulkAssignSubscription assigns subscriptions to multiple users
@@ -373,10 +348,11 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
    result := &BulkAssignResult{
        Subscriptions: make([]UserSubscription, 0),
        Errors:        make([]string, 0),
        Statuses:      make(map[int64]string),
    }

    for _, userID := range input.UserIDs {
        sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{
        sub, reused, err := s.assignSubscriptionWithReuse(ctx, &AssignSubscriptionInput{
            UserID:       userID,
            GroupID:      input.GroupID,
            ValidityDays: input.ValidityDays,
@@ -386,15 +362,105 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
        if err != nil {
            result.FailedCount++
            result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
            result.Statuses[userID] = "failed"
        } else {
            result.SuccessCount++
            result.Subscriptions = append(result.Subscriptions, *sub)
            if reused {
                result.ReusedCount++
                result.Statuses[userID] = "reused"
            } else {
                result.CreatedCount++
                result.Statuses[userID] = "created"
            }
        }
    }

    return result, nil
}

func (s *SubscriptionService) assignSubscriptionWithReuse(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
    // Check that the group exists and is a subscription-type group
    group, err := s.groupRepo.GetByID(ctx, input.GroupID)
    if err != nil {
        return nil, false, fmt.Errorf("group not found: %w", err)
    }
    if !group.IsSubscriptionType() {
        return nil, false, ErrGroupNotSubscriptionType
    }

    // Check whether a subscription already exists; if so, return the existing one as an idempotent success
    exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
    if err != nil {
        return nil, false, err
    }
    if exists {
        sub, getErr := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
        if getErr != nil {
            return nil, false, getErr
        }
        if conflictReason, conflict := detectAssignSemanticConflict(sub, input); conflict {
            return nil, false, ErrSubscriptionAssignConflict.WithMetadata(map[string]string{
                "conflict_reason": conflictReason,
            })
        }
        return sub, true, nil
    }

    sub, err := s.createSubscription(ctx, input)
    if err != nil {
        return nil, false, err
    }

    // Invalidate subscription caches
    s.InvalidateSubCache(input.UserID, input.GroupID)
    if s.billingCacheService != nil {
        userID, groupID := input.UserID, input.GroupID
        go func() {
            cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
            defer cancel()
            _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
        }()
    }

    return sub, false, nil
}

func detectAssignSemanticConflict(existing *UserSubscription, input *AssignSubscriptionInput) (string, bool) {
    if existing == nil || input == nil {
        return "", false
    }

    normalizedDays := normalizeAssignValidityDays(input.ValidityDays)
    if !existing.StartsAt.IsZero() {
        expectedExpiresAt := existing.StartsAt.AddDate(0, 0, normalizedDays)
        if expectedExpiresAt.After(MaxExpiresAt) {
            expectedExpiresAt = MaxExpiresAt
        }
        if !existing.ExpiresAt.Equal(expectedExpiresAt) {
            return "validity_days_mismatch", true
        }
    }

    existingNotes := strings.TrimSpace(existing.Notes)
    inputNotes := strings.TrimSpace(input.Notes)
    if existingNotes != inputNotes {
        return "notes_mismatch", true
    }

    return "", false
}

func normalizeAssignValidityDays(days int) int {
    if days <= 0 {
        days = 30
    }
    if days > MaxValidityDays {
        days = MaxValidityDays
    }
    return days
}
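
A quick worked example of the expiry arithmetic above (the dates are illustrative): with StartsAt = 2026-02-20 and ValidityDays = 30, the expected expiry is 2026-03-22, so a record seeded with a 60-day window fails the Equal check and reports validity_days_mismatch.

func exampleExpiryMismatch() {
    start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
    expected := start.AddDate(0, 0, 30) // 2026-03-22T10:00:00Z for ValidityDays=30
    stored := start.AddDate(0, 0, 60)   // a 60-day subscription seeded earlier
    fmt.Println(stored.Equal(expected)) // false → detectAssignSemanticConflict returns "validity_days_mismatch"
}
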
// RevokeSubscription revokes a subscription
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
    // Fetch the subscription first so its caches can be invalidated
214
backend/internal/service/system_operation_lock_service.go
Normal file
@@ -0,0 +1,214 @@
package service

import (
    "context"
    "fmt"
    "strconv"
    "sync"
    "time"

    infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
    "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)

const (
    systemOperationLockScope = "admin.system.operations.global_lock"
    systemOperationLockKey   = "global-system-operation-lock"
)

var (
    ErrSystemOperationBusy = infraerrors.Conflict("SYSTEM_OPERATION_BUSY", "another system operation is in progress")
)

type SystemOperationLock struct {
    recordID    int64
    operationID string

    stopOnce sync.Once
    stopCh   chan struct{}
}

func (l *SystemOperationLock) OperationID() string {
    if l == nil {
        return ""
    }
    return l.operationID
}

type SystemOperationLockService struct {
    repo IdempotencyRepository

    lease         time.Duration
    renewInterval time.Duration
    ttl           time.Duration
}

func NewSystemOperationLockService(repo IdempotencyRepository, cfg IdempotencyConfig) *SystemOperationLockService {
    lease := cfg.ProcessingTimeout
    if lease <= 0 {
        lease = 30 * time.Second
    }
    renewInterval := lease / 3
    if renewInterval < time.Second {
        renewInterval = time.Second
    }
    ttl := cfg.SystemOperationTTL
    if ttl <= 0 {
        ttl = time.Hour
    }

    return &SystemOperationLockService{
        repo:          repo,
        lease:         lease,
        renewInterval: renewInterval,
        ttl:           ttl,
    }
}

func (s *SystemOperationLockService) Acquire(ctx context.Context, operationID string) (*SystemOperationLock, error) {
    if s == nil || s.repo == nil {
        return nil, ErrIdempotencyStoreUnavail
    }
    if operationID == "" {
        return nil, infraerrors.BadRequest("SYSTEM_OPERATION_ID_REQUIRED", "operation id is required")
    }

    now := time.Now()
    expiresAt := now.Add(s.ttl)
    lockedUntil := now.Add(s.lease)
    keyHash := HashIdempotencyKey(systemOperationLockKey)

    record := &IdempotencyRecord{
        Scope:              systemOperationLockScope,
        IdempotencyKeyHash: keyHash,
        RequestFingerprint: operationID,
        Status:             IdempotencyStatusProcessing,
        LockedUntil:        &lockedUntil,
        ExpiresAt:          expiresAt,
    }

    owner, err := s.repo.CreateProcessing(ctx, record)
    if err != nil {
        return nil, ErrIdempotencyStoreUnavail.WithCause(err)
    }
    if !owner {
        existing, getErr := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash)
        if getErr != nil {
            return nil, ErrIdempotencyStoreUnavail.WithCause(getErr)
        }
        if existing == nil {
            return nil, ErrIdempotencyStoreUnavail
        }
        if existing.Status == IdempotencyStatusProcessing && existing.LockedUntil != nil && existing.LockedUntil.After(now) {
            return nil, s.busyError(existing.RequestFingerprint, existing.LockedUntil, now)
        }
        reclaimed, reclaimErr := s.repo.TryReclaim(
            ctx,
            existing.ID,
            existing.Status,
            now,
            lockedUntil,
            expiresAt,
        )
        if reclaimErr != nil {
            return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
        }
        if !reclaimed {
            latest, _ := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash)
            if latest != nil {
                return nil, s.busyError(latest.RequestFingerprint, latest.LockedUntil, now)
            }
            return nil, ErrSystemOperationBusy
        }
        record.ID = existing.ID
    }

    if record.ID == 0 {
        return nil, ErrIdempotencyStoreUnavail
    }

    lock := &SystemOperationLock{
        recordID:    record.ID,
        operationID: operationID,
        stopCh:      make(chan struct{}),
    }
    go s.renewLoop(lock)

    return lock, nil
}

func (s *SystemOperationLockService) Release(ctx context.Context, lock *SystemOperationLock, succeeded bool, failureReason string) error {
    if s == nil || s.repo == nil || lock == nil {
        return nil
    }

    lock.stopOnce.Do(func() {
        close(lock.stopCh)
    })

    if ctx == nil {
        ctx = context.Background()
    }

    expiresAt := time.Now().Add(s.ttl)
    if succeeded {
        responseBody := fmt.Sprintf(`{"operation_id":"%s","released":true}`, lock.operationID)
        return s.repo.MarkSucceeded(ctx, lock.recordID, 200, responseBody, expiresAt)
    }

    reason := failureReason
    if reason == "" {
        reason = "SYSTEM_OPERATION_FAILED"
    }
    return s.repo.MarkFailedRetryable(ctx, lock.recordID, reason, time.Now(), expiresAt)
}

func (s *SystemOperationLockService) renewLoop(lock *SystemOperationLock) {
    ticker := time.NewTicker(s.renewInterval)
    defer ticker.Stop()

    for {
        select {
        case <-ticker.C:
            now := time.Now()
            ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
            ok, err := s.repo.ExtendProcessingLock(
                ctx,
                lock.recordID,
                lock.operationID,
                now.Add(s.lease),
                now.Add(s.ttl),
            )
            cancel()
            if err != nil {
                logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew failed operation_id=%s err=%v", lock.operationID, err)
                // A transient failure must not terminate the renew goroutine; retry on the next tick.
                continue
            }
            if !ok {
                logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew stopped operation_id=%s reason=ownership_lost", lock.operationID)
                return
            }
        case <-lock.stopCh:
            return
        }
    }
}

func (s *SystemOperationLockService) busyError(operationID string, lockedUntil *time.Time, now time.Time) error {
    metadata := make(map[string]string)
    if operationID != "" {
        metadata["operation_id"] = operationID
    }
    if lockedUntil != nil {
        sec := int(lockedUntil.Sub(now).Seconds())
        if sec <= 0 {
            sec = 1
        }
        metadata["retry_after"] = strconv.Itoa(sec)
    }
    if len(metadata) == 0 {
        return ErrSystemOperationBusy
    }
    return ErrSystemOperationBusy.WithMetadata(metadata)
}
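
A minimal caller sketch of the acquire/release contract above (the operation id and doMaintenance are hypothetical): release on every path, reporting success or failure so the record is marked accordingly.

func runExclusiveMaintenance(ctx context.Context, locks *SystemOperationLockService) error {
    lock, err := locks.Acquire(ctx, "maintenance-2026-02-20") // operation id is illustrative
    if err != nil {
        return err // SYSTEM_OPERATION_BUSY carries operation_id / retry_after metadata
    }
    if workErr := doMaintenance(ctx); workErr != nil { // doMaintenance is a hypothetical work function
        _ = locks.Release(ctx, lock, false, "MAINTENANCE_FAILED")
        return workErr
    }
    return locks.Release(ctx, lock, true, "")
}
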
305
backend/internal/service/system_operation_lock_service_test.go
Normal file
@@ -0,0 +1,305 @@
package service

import (
    "context"
    "errors"
    "sync/atomic"
    "testing"
    "time"

    infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
    "github.com/stretchr/testify/require"
)

func TestSystemOperationLockService_AcquireBusyAndRelease(t *testing.T) {
    repo := newInMemoryIdempotencyRepo()
    svc := NewSystemOperationLockService(repo, IdempotencyConfig{
        SystemOperationTTL: 10 * time.Second,
        ProcessingTimeout:  2 * time.Second,
    })

    lock1, err := svc.Acquire(context.Background(), "op-1")
    require.NoError(t, err)
    require.NotNil(t, lock1)

    _, err = svc.Acquire(context.Background(), "op-2")
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
    appErr := infraerrors.FromError(err)
    require.Equal(t, "op-1", appErr.Metadata["operation_id"])
    require.NotEmpty(t, appErr.Metadata["retry_after"])

    require.NoError(t, svc.Release(context.Background(), lock1, true, ""))

    lock2, err := svc.Acquire(context.Background(), "op-2")
    require.NoError(t, err)
    require.NotNil(t, lock2)
    require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
}

func TestSystemOperationLockService_RenewLease(t *testing.T) {
    repo := newInMemoryIdempotencyRepo()
    svc := NewSystemOperationLockService(repo, IdempotencyConfig{
        SystemOperationTTL: 5 * time.Second,
        ProcessingTimeout:  1200 * time.Millisecond,
    })

    lock, err := svc.Acquire(context.Background(), "op-renew")
    require.NoError(t, err)
    require.NotNil(t, lock)
    defer func() {
        _ = svc.Release(context.Background(), lock, true, "")
    }()

    keyHash := HashIdempotencyKey(systemOperationLockKey)
    initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
    require.NotNil(t, initial)
    require.NotNil(t, initial.LockedUntil)
    initialLockedUntil := *initial.LockedUntil

    time.Sleep(1500 * time.Millisecond)

    updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
    require.NotNil(t, updated)
    require.NotNil(t, updated.LockedUntil)
    require.True(t, updated.LockedUntil.After(initialLockedUntil), "locked_until should be renewed while lock is held")
}

type flakySystemLockRenewRepo struct {
    *inMemoryIdempotencyRepo
    extendCalls int32
}

func (r *flakySystemLockRenewRepo) ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
    call := atomic.AddInt32(&r.extendCalls, 1)
    if call == 1 {
        return false, errors.New("transient extend failure")
    }
    return r.inMemoryIdempotencyRepo.ExtendProcessingLock(ctx, id, requestFingerprint, newLockedUntil, newExpiresAt)
}

func TestSystemOperationLockService_RenewLeaseContinuesAfterTransientFailure(t *testing.T) {
    repo := &flakySystemLockRenewRepo{inMemoryIdempotencyRepo: newInMemoryIdempotencyRepo()}
    svc := NewSystemOperationLockService(repo, IdempotencyConfig{
        SystemOperationTTL: 5 * time.Second,
        ProcessingTimeout:  2400 * time.Millisecond,
    })

    lock, err := svc.Acquire(context.Background(), "op-renew-transient")
    require.NoError(t, err)
    require.NotNil(t, lock)
    defer func() {
        _ = svc.Release(context.Background(), lock, true, "")
    }()

    keyHash := HashIdempotencyKey(systemOperationLockKey)
    initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
    require.NotNil(t, initial)
    require.NotNil(t, initial.LockedUntil)
    initialLockedUntil := *initial.LockedUntil

    // After the first renewal fails, the next tick should retry and successfully extend the lock expiry.
    require.Eventually(t, func() bool {
        updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
        if updated == nil || updated.LockedUntil == nil {
            return false
        }
        return atomic.LoadInt32(&repo.extendCalls) >= 2 && updated.LockedUntil.After(initialLockedUntil)
    }, 4*time.Second, 100*time.Millisecond, "renew loop should continue after transient error")
}

func TestSystemOperationLockService_SameOperationIDRetryWhileRunning(t *testing.T) {
    repo := newInMemoryIdempotencyRepo()
    svc := NewSystemOperationLockService(repo, IdempotencyConfig{
        SystemOperationTTL: 10 * time.Second,
        ProcessingTimeout:  2 * time.Second,
    })

    lock1, err := svc.Acquire(context.Background(), "op-same")
    require.NoError(t, err)
    require.NotNil(t, lock1)

    _, err = svc.Acquire(context.Background(), "op-same")
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
    appErr := infraerrors.FromError(err)
    require.Equal(t, "op-same", appErr.Metadata["operation_id"])

    require.NoError(t, svc.Release(context.Background(), lock1, true, ""))

    lock2, err := svc.Acquire(context.Background(), "op-same")
    require.NoError(t, err)
    require.NotNil(t, lock2)
    require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
}

func TestSystemOperationLockService_RecoverAfterLeaseExpired(t *testing.T) {
    repo := newInMemoryIdempotencyRepo()
    svc := NewSystemOperationLockService(repo, IdempotencyConfig{
        SystemOperationTTL: 5 * time.Second,
        ProcessingTimeout:  300 * time.Millisecond,
    })

    lock1, err := svc.Acquire(context.Background(), "op-crashed")
    require.NoError(t, err)
    require.NotNil(t, lock1)

    // Simulate an instance crash: stop renewing without calling Release.
    lock1.stopOnce.Do(func() {
        close(lock1.stopCh)
    })

    time.Sleep(450 * time.Millisecond)

    lock2, err := svc.Acquire(context.Background(), "op-recovered")
    require.NoError(t, err, "expired lease should allow a new operation to reclaim lock")
    require.NotNil(t, lock2)
    require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
}

type systemLockRepoStub struct {
    createOwner bool
    createErr   error
    existing    *IdempotencyRecord
    getErr      error
    reclaimOK   bool
    reclaimErr  error
    markSuccErr error
    markFailErr error
}

func (s *systemLockRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
    if s.createErr != nil {
        return false, s.createErr
    }
    return s.createOwner, nil
}

func (s *systemLockRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
    if s.getErr != nil {
        return nil, s.getErr
    }
    return cloneRecord(s.existing), nil
}

func (s *systemLockRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
    if s.reclaimErr != nil {
        return false, s.reclaimErr
    }
    return s.reclaimOK, nil
}

func (s *systemLockRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
    return true, nil
}

func (s *systemLockRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
    return s.markSuccErr
}

func (s *systemLockRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
    return s.markFailErr
}

func (s *systemLockRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
    return 0, nil
}

func TestSystemOperationLockService_InputAndStoreErrorBranches(t *testing.T) {
    var nilSvc *SystemOperationLockService
    _, err := nilSvc.Acquire(context.Background(), "x")
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))

    svc := &SystemOperationLockService{repo: nil}
    _, err = svc.Acquire(context.Background(), "x")
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))

    svc = NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{
        SystemOperationTTL: 10 * time.Second,
        ProcessingTimeout:  2 * time.Second,
    })
    _, err = svc.Acquire(context.Background(), "")
    require.Error(t, err)
    require.Equal(t, "SYSTEM_OPERATION_ID_REQUIRED", infraerrors.Reason(err))

    badStore := &systemLockRepoStub{createErr: errors.New("db down")}
    svc = NewSystemOperationLockService(badStore, IdempotencyConfig{
        SystemOperationTTL: 10 * time.Second,
        ProcessingTimeout:  2 * time.Second,
    })
    _, err = svc.Acquire(context.Background(), "x")
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
}

func TestSystemOperationLockService_ExistingNilAndReclaimBranches(t *testing.T) {
    now := time.Now()
    repo := &systemLockRepoStub{
        createOwner: false,
    }
    svc := NewSystemOperationLockService(repo, IdempotencyConfig{
        SystemOperationTTL: 10 * time.Second,
        ProcessingTimeout:  2 * time.Second,
    })

    _, err := svc.Acquire(context.Background(), "op")
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))

    repo.existing = &IdempotencyRecord{
        ID:                 1,
        Scope:              systemOperationLockScope,
        IdempotencyKeyHash: HashIdempotencyKey(systemOperationLockKey),
        RequestFingerprint: "other-op",
        Status:             IdempotencyStatusFailedRetryable,
        LockedUntil:        ptrTime(now.Add(-time.Second)),
        ExpiresAt:          now.Add(time.Hour),
    }
    repo.reclaimErr = errors.New("reclaim failed")
    _, err = svc.Acquire(context.Background(), "op")
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))

    repo.reclaimErr = nil
    repo.reclaimOK = false
    _, err = svc.Acquire(context.Background(), "op")
    require.Error(t, err)
    require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
}

func TestSystemOperationLockService_ReleaseBranchesAndOperationID(t *testing.T) {
    require.Equal(t, "", (*SystemOperationLock)(nil).OperationID())

    svc := NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{
        SystemOperationTTL: 10 * time.Second,
        ProcessingTimeout:  2 * time.Second,
    })
    lock, err := svc.Acquire(context.Background(), "op")
    require.NoError(t, err)
    require.NotNil(t, lock)

    require.NoError(t, svc.Release(context.Background(), lock, false, ""))
    require.NoError(t, svc.Release(context.Background(), lock, true, ""))

    repo := &systemLockRepoStub{
        createOwner: true,
        markSuccErr: errors.New("mark succeeded failed"),
        markFailErr: errors.New("mark failed failed"),
    }
    svc = NewSystemOperationLockService(repo, IdempotencyConfig{
        SystemOperationTTL: 10 * time.Second,
        ProcessingTimeout:  2 * time.Second,
    })
    lock = &SystemOperationLock{recordID: 1, operationID: "op2", stopCh: make(chan struct{})}
    require.Error(t, svc.Release(context.Background(), lock, true, ""))
    lock = &SystemOperationLock{recordID: 1, operationID: "op3", stopCh: make(chan struct{})}
    require.Error(t, svc.Release(context.Background(), lock, false, "BAD"))

    var nilLockSvc *SystemOperationLockService
    require.NoError(t, nilLockSvc.Release(context.Background(), nil, true, ""))

    err = svc.busyError("", nil, time.Now())
    require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
}
@@ -320,6 +320,10 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc
        return err
    }
    logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
    if status == UsageCleanupStatusCanceled {
        logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent hit: task=%d operator=%d", taskID, canceledBy)
        return nil
    }
    if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
        return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
    }
@@ -329,6 +333,11 @@
    }
    if !ok {
        // The status may have changed concurrently
        currentStatus, getErr := s.repo.GetTaskStatus(ctx, taskID)
        if getErr == nil && currentStatus == UsageCleanupStatusCanceled {
            logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent race hit: task=%d operator=%d", taskID, canceledBy)
            return nil
        }
        return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
    }
    logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)

@@ -644,6 +644,23 @@ func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) {
    require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
}

func TestUsageCleanupServiceCancelTaskAlreadyCanceledIsIdempotent(t *testing.T) {
    repo := &cleanupRepoStub{
        statusByID: map[int64]string{
            7: UsageCleanupStatusCanceled,
        },
    }
    cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
    svc := NewUsageCleanupService(repo, nil, nil, cfg)

    err := svc.CancelTask(context.Background(), 7, 1)
    require.NoError(t, err)

    repo.mu.Lock()
    defer repo.mu.Unlock()
    require.Empty(t, repo.cancelCalls, "already canceled should return success without extra cancel write")
}

func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) {
    shouldCancel := false
    repo := &cleanupRepoStub{

@@ -225,6 +225,45 @@ func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Confi
    return svc
}

func buildIdempotencyConfig(cfg *config.Config) IdempotencyConfig {
    idempotencyCfg := DefaultIdempotencyConfig()
    if cfg != nil {
        if cfg.Idempotency.DefaultTTLSeconds > 0 {
            idempotencyCfg.DefaultTTL = time.Duration(cfg.Idempotency.DefaultTTLSeconds) * time.Second
        }
        if cfg.Idempotency.SystemOperationTTLSeconds > 0 {
            idempotencyCfg.SystemOperationTTL = time.Duration(cfg.Idempotency.SystemOperationTTLSeconds) * time.Second
        }
        if cfg.Idempotency.ProcessingTimeoutSeconds > 0 {
            idempotencyCfg.ProcessingTimeout = time.Duration(cfg.Idempotency.ProcessingTimeoutSeconds) * time.Second
        }
        if cfg.Idempotency.FailedRetryBackoffSeconds > 0 {
            idempotencyCfg.FailedRetryBackoff = time.Duration(cfg.Idempotency.FailedRetryBackoffSeconds) * time.Second
        }
        if cfg.Idempotency.MaxStoredResponseLen > 0 {
            idempotencyCfg.MaxStoredResponseLen = cfg.Idempotency.MaxStoredResponseLen
        }
        idempotencyCfg.ObserveOnly = cfg.Idempotency.ObserveOnly
    }
    return idempotencyCfg
}
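
How these knobs translate in practice, as a sketch in the same package (values are illustrative; the field names are exactly those read above, and the ObserveOnly semantics are inferred from its name):

func exampleIdempotencyConfig() IdempotencyConfig {
    cfg := &config.Config{}
    cfg.Idempotency.DefaultTTLSeconds = 3600      // → DefaultTTL = 1h
    cfg.Idempotency.ProcessingTimeoutSeconds = 30 // → ProcessingTimeout = 30s
    cfg.Idempotency.ObserveOnly = true            // presumably record outcomes without enforcing
    return buildIdempotencyConfig(cfg)
}
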
func ProvideIdempotencyCoordinator(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCoordinator {
    coordinator := NewIdempotencyCoordinator(repo, buildIdempotencyConfig(cfg))
    SetDefaultIdempotencyCoordinator(coordinator)
    return coordinator
}

func ProvideSystemOperationLockService(repo IdempotencyRepository, cfg *config.Config) *SystemOperationLockService {
    return NewSystemOperationLockService(repo, buildIdempotencyConfig(cfg))
}

func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService {
    svc := NewIdempotencyCleanupService(repo, cfg)
    svc.Start()
    return svc
}

// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
func ProvideOpsScheduledReportService(
    opsService *OpsService,
@@ -318,4 +357,7 @@ var ProviderSet = wire.NewSet(
    NewTotpService,
    NewErrorPassthroughService,
    NewDigestSessionStore,
    ProvideIdempotencyCoordinator,
    ProvideSystemOperationLockService,
    ProvideIdempotencyCleanupService,
)