feat(idempotency): 为关键写接口接入幂等并完善并发容错

This commit is contained in:
yangjianbo
2026-02-23 12:45:37 +08:00
parent 3b6584cc8d
commit 5fa45f3b8c
40 changed files with 4383 additions and 223 deletions

View File

@@ -35,6 +35,8 @@ var (
const (
	// apiKeyMaxErrorsPerHour caps recorded errors per key per hour.
	apiKeyMaxErrorsPerHour = 20
	// apiKeyLastUsedMinTouch is the minimum gap between last-used DB writes
	// for the same key (debounce window).
	apiKeyLastUsedMinTouch = 30 * time.Second
	// Short backoff after a failed last-used DB write, so the request path
	// does not keep retrying synchronously and cause a write storm / high
	// latency.
	apiKeyLastUsedFailBackoff = 5 * time.Second
)
type APIKeyRepository interface {
@@ -129,7 +131,7 @@ type APIKeyService struct {
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
lastUsedTouchL1 sync.Map // keyID -> time.Time
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
lastUsedTouchSF singleflight.Group
}
@@ -574,7 +576,7 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
now := time.Now()
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
if last, ok := v.(time.Time); ok && now.Sub(last) < apiKeyLastUsedMinTouch {
if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
return nil
}
}
@@ -582,15 +584,16 @@ func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
_, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) {
latest := time.Now()
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
if last, ok := v.(time.Time); ok && latest.Sub(last) < apiKeyLastUsedMinTouch {
if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
return nil, nil
}
}
if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil {
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff))
return nil, fmt.Errorf("touch api key last used: %w", err)
}
s.lastUsedTouchL1.Store(keyID, latest)
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch))
return nil, nil
})
return err

View File

@@ -79,8 +79,27 @@ func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) {
require.ErrorContains(t, err, "touch api key last used")
require.Equal(t, []int64{123}, repo.touchedIDs)
_, ok := svc.lastUsedTouchL1.Load(int64(123))
require.False(t, ok, "failed touch should not update debounce cache")
cached, ok := svc.lastUsedTouchL1.Load(int64(123))
require.True(t, ok, "failed touch should still update retry debounce cache")
_, isTime := cached.(time.Time)
require.True(t, isTime)
}
// TestAPIKeyService_TouchLastUsed_RepoErrorDebounced verifies that a failed
// UpdateLastUsed write still arms the debounce cache, so the immediately
// following touch for the same key is skipped instead of hitting the
// repository again (write-storm protection).
func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) {
	repo := &apiKeyRepoStub{
		updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
			return errors.New("db write failed")
		},
	}
	svc := &APIKeyService{apiKeyRepo: repo}
	// First touch hits the repo and fails.
	firstErr := svc.TouchLastUsed(context.Background(), 456)
	require.Error(t, firstErr)
	require.ErrorContains(t, firstErr, "touch api key last used")
	// Second touch lands inside the failure backoff window and is a no-op.
	secondErr := svc.TouchLastUsed(context.Background(), 456)
	require.NoError(t, secondErr, "failed touch should be debounced and skip immediate retry")
	require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again")
}
type touchSingleflightRepo struct {

View File

@@ -0,0 +1,471 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
)
// Lifecycle states of an idempotency record.
const (
	IdempotencyStatusProcessing      = "processing"       // owner is executing; others must wait
	IdempotencyStatusSucceeded       = "succeeded"        // response stored; retries replay it
	IdempotencyStatusFailedRetryable = "failed_retryable" // execution failed; reclaimable after backoff
)

// Sentinel application errors returned by the coordinator.
var (
	ErrIdempotencyKeyRequired    = infraerrors.BadRequest("IDEMPOTENCY_KEY_REQUIRED", "idempotency key is required")
	ErrIdempotencyKeyInvalid     = infraerrors.BadRequest("IDEMPOTENCY_KEY_INVALID", "idempotency key is invalid")
	ErrIdempotencyKeyConflict    = infraerrors.Conflict("IDEMPOTENCY_KEY_CONFLICT", "idempotency key reused with different payload")
	ErrIdempotencyInProgress     = infraerrors.Conflict("IDEMPOTENCY_IN_PROGRESS", "idempotent request is still processing")
	ErrIdempotencyRetryBackoff   = infraerrors.Conflict("IDEMPOTENCY_RETRY_BACKOFF", "idempotent request is in retry backoff window")
	ErrIdempotencyStoreUnavail   = infraerrors.ServiceUnavailable("IDEMPOTENCY_STORE_UNAVAILABLE", "idempotency store unavailable")
	ErrIdempotencyInvalidPayload = infraerrors.BadRequest("IDEMPOTENCY_PAYLOAD_INVALID", "failed to normalize request payload")
)

// IdempotencyRecord is one persisted idempotency entry, keyed by
// (Scope, IdempotencyKeyHash).
type IdempotencyRecord struct {
	ID                 int64
	Scope              string // logical operation namespace
	IdempotencyKeyHash string // SHA-256 hex of the client-supplied key
	RequestFingerprint string // hash of method/route/actor/payload; detects key reuse
	Status             string // one of the IdempotencyStatus* constants
	ResponseStatus     *int
	ResponseBody       *string    // stored (redacted) success response for replay
	ErrorReason        *string    // last failure reason when failed_retryable
	LockedUntil        *time.Time // processing lock / retry-backoff deadline
	ExpiresAt          time.Time  // record eligible for cleanup after this
	CreatedAt          time.Time
	UpdatedAt          time.Time
}

// IdempotencyRepository is the persistence contract the coordinator needs.
type IdempotencyRepository interface {
	// CreateProcessing inserts record in processing state; returns true when
	// this caller won the insert (first-writer-wins).
	CreateProcessing(ctx context.Context, record *IdempotencyRecord) (bool, error)
	GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*IdempotencyRecord, error)
	// TryReclaim atomically moves an unlocked record from fromStatus back to
	// processing; returns true when this caller won the reclaim.
	TryReclaim(ctx context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error)
	ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error)
	MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error
	MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error
	// DeleteExpired removes up to limit records expired at now.
	DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error)
}

// IdempotencyConfig tunes coordinator behavior.
type IdempotencyConfig struct {
	DefaultTTL           time.Duration // record lifetime for regular write endpoints
	SystemOperationTTL   time.Duration // record lifetime for system operations
	ProcessingTimeout    time.Duration // how long a processing lock is honored
	FailedRetryBackoff   time.Duration // wait before a failed request may be retried
	MaxStoredResponseLen int           // cap on stored response bytes (0 = unlimited)
	ObserveOnly          bool          // when true, missing keys are tolerated even if required
}
// DefaultIdempotencyConfig returns the built-in defaults used when no explicit
// configuration is supplied.
func DefaultIdempotencyConfig() IdempotencyConfig {
	return IdempotencyConfig{
		DefaultTTL:           24 * time.Hour,
		SystemOperationTTL:   1 * time.Hour,
		ProcessingTimeout:    30 * time.Second,
		FailedRetryBackoff:   5 * time.Second,
		MaxStoredResponseLen: 64 * 1024,
		ObserveOnly:          true, // observe first, enforce later, so old clients are not broken immediately
	}
}
// IdempotencyExecuteOptions describes one protected write invocation.
type IdempotencyExecuteOptions struct {
	Scope          string // logical operation namespace (required when a key is present)
	ActorScope     string // caller identity folded into the fingerprint
	Method         string // HTTP method; defaults to POST when empty
	Route          string // endpoint label for metrics/audit; defaults to "/"
	IdempotencyKey string // client-supplied key; blank disables idempotency
	Payload        any    // request body; JSON-marshaled into the fingerprint
	TTL            time.Duration // record lifetime; <=0 falls back to cfg.DefaultTTL
	RequireKey     bool   // reject blank keys when enforcement is enabled
}

// IdempotencyExecuteResult carries the operation result and whether it was
// replayed from a stored prior success.
type IdempotencyExecuteResult struct {
	Data     any
	Replayed bool
}

// IdempotencyCoordinator implements the idempotent-execution state machine on
// top of an IdempotencyRepository.
type IdempotencyCoordinator struct {
	repo IdempotencyRepository
	cfg  IdempotencyConfig
}

// Process-wide default coordinator, guarded by defaultIdempotencyMu.
var (
	defaultIdempotencyMu  sync.RWMutex
	defaultIdempotencySvc *IdempotencyCoordinator
)
// SetDefaultIdempotencyCoordinator installs the process-wide coordinator.
// Safe for concurrent use with DefaultIdempotencyCoordinator.
func SetDefaultIdempotencyCoordinator(svc *IdempotencyCoordinator) {
	defaultIdempotencyMu.Lock()
	defaultIdempotencySvc = svc
	defaultIdempotencyMu.Unlock()
}

// DefaultIdempotencyCoordinator returns the installed process-wide
// coordinator, or nil when none has been set.
func DefaultIdempotencyCoordinator() *IdempotencyCoordinator {
	defaultIdempotencyMu.RLock()
	defer defaultIdempotencyMu.RUnlock()
	return defaultIdempotencySvc
}
func DefaultWriteIdempotencyTTL() time.Duration {
defaultTTL := DefaultIdempotencyConfig().DefaultTTL
if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.DefaultTTL > 0 {
return coordinator.cfg.DefaultTTL
}
return defaultTTL
}
func DefaultSystemOperationIdempotencyTTL() time.Duration {
defaultTTL := DefaultIdempotencyConfig().SystemOperationTTL
if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.SystemOperationTTL > 0 {
return coordinator.cfg.SystemOperationTTL
}
return defaultTTL
}
// NewIdempotencyCoordinator builds a coordinator backed by repo with the given
// configuration.
func NewIdempotencyCoordinator(repo IdempotencyRepository, cfg IdempotencyConfig) *IdempotencyCoordinator {
	coordinator := &IdempotencyCoordinator{repo: repo, cfg: cfg}
	return coordinator
}
// NormalizeIdempotencyKey trims raw and validates it: at most 128 bytes of
// printable ASCII (no spaces). A blank input normalizes to "" with no error,
// meaning "no key supplied".
func NormalizeIdempotencyKey(raw string) (string, error) {
	trimmed := strings.TrimSpace(raw)
	switch {
	case trimmed == "":
		return "", nil
	case len(trimmed) > 128:
		return "", ErrIdempotencyKeyInvalid
	}
	for _, ch := range trimmed {
		// Accept only visible ASCII ('!' .. '~'); anything else is invalid.
		if ch < '!' || ch > '~' {
			return "", ErrIdempotencyKeyInvalid
		}
	}
	return trimmed, nil
}
// HashIdempotencyKey returns the lowercase hex SHA-256 digest of key, used as
// the storage key so raw client keys are never persisted.
func HashIdempotencyKey(key string) string {
	digest := sha256.Sum256([]byte(key))
	return hex.EncodeToString(digest[:])
}
// BuildIdempotencyFingerprint hashes the request identity — upper-cased
// method, route, actor scope, and the canonical JSON encoding of payload —
// into a hex SHA-256 digest. Empty components default to "POST", "/", and
// "anonymous". Returns ErrIdempotencyInvalidPayload when payload cannot be
// marshaled.
func BuildIdempotencyFingerprint(method, route, actorScope string, payload any) (string, error) {
	if method == "" {
		method = "POST"
	}
	if route == "" {
		route = "/"
	}
	if actorScope == "" {
		actorScope = "anonymous"
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return "", ErrIdempotencyInvalidPayload.WithCause(err)
	}
	var canonical strings.Builder
	canonical.WriteString(strings.ToUpper(method))
	canonical.WriteByte('\n')
	canonical.WriteString(route)
	canonical.WriteByte('\n')
	canonical.WriteString(actorScope)
	canonical.WriteByte('\n')
	canonical.Write(encoded)
	digest := sha256.Sum256([]byte(canonical.String()))
	return hex.EncodeToString(digest[:]), nil
}
// RetryAfterSecondsFromError extracts the "retry_after" metadata value (whole
// seconds) from an application error. It returns 0 when err is not an
// application error, carries no metadata, or the value is not a positive
// integer.
func RetryAfterSecondsFromError(err error) int {
	appErr := new(infraerrors.ApplicationError)
	if !errors.As(err, &appErr) || appErr == nil || appErr.Metadata == nil {
		return 0
	}
	raw := strings.TrimSpace(appErr.Metadata["retry_after"])
	if raw == "" {
		return 0
	}
	if seconds, parseErr := strconv.Atoi(raw); parseErr == nil && seconds > 0 {
		return seconds
	}
	return 0
}
// Execute runs a write operation under idempotency protection.
//
// Flow:
//   - A blank key degrades to plain execution, unless RequireKey is set and
//     enforcement (ObserveOnly=false) is on, in which case the request is
//     rejected.
//   - The first caller atomically claims a "processing" record and executes.
//   - Concurrent/later callers observe the existing record and either replay
//     the stored success (Replayed=true), receive an in-progress / backoff
//     conflict with a retry_after hint, or reclaim an expired or
//     failed_retryable record and execute themselves.
//   - Reusing a key with a different request fingerprint is a hard conflict.
//
// On success the redacted response is persisted for replay; on failure the
// record is marked failed_retryable with a short backoff window.
func (c *IdempotencyCoordinator) Execute(
	ctx context.Context,
	opts IdempotencyExecuteOptions,
	execute func(context.Context) (any, error),
) (*IdempotencyExecuteResult, error) {
	if execute == nil {
		return nil, infraerrors.InternalServer("IDEMPOTENCY_EXECUTOR_NIL", "idempotency executor is nil")
	}
	key, err := NormalizeIdempotencyKey(opts.IdempotencyKey)
	if err != nil {
		return nil, err
	}
	if key == "" {
		// No key: run without idempotency bookkeeping unless a key is both
		// required and enforced.
		if opts.RequireKey && !c.cfg.ObserveOnly {
			return nil, ErrIdempotencyKeyRequired
		}
		data, execErr := execute(ctx)
		if execErr != nil {
			return nil, execErr
		}
		return &IdempotencyExecuteResult{Data: data}, nil
	}
	if c.repo == nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "repo_nil")
		return nil, ErrIdempotencyStoreUnavail
	}
	if opts.Scope == "" {
		return nil, infraerrors.BadRequest("IDEMPOTENCY_SCOPE_REQUIRED", "idempotency scope is required")
	}
	fingerprint, err := BuildIdempotencyFingerprint(opts.Method, opts.Route, opts.ActorScope, opts.Payload)
	if err != nil {
		return nil, err
	}
	ttl := opts.TTL
	if ttl <= 0 {
		ttl = c.cfg.DefaultTTL
	}
	// Deadlines are computed once up front and reused for both the initial
	// claim and any reclaim below.
	now := time.Now()
	expiresAt := now.Add(ttl)
	lockedUntil := now.Add(c.cfg.ProcessingTimeout)
	keyHash := HashIdempotencyKey(key)
	record := &IdempotencyRecord{
		Scope:              opts.Scope,
		IdempotencyKeyHash: keyHash,
		RequestFingerprint: fingerprint,
		Status:             IdempotencyStatusProcessing,
		LockedUntil:        &lockedUntil,
		ExpiresAt:          expiresAt,
	}
	// First-writer-wins: owner==true means this call inserted the record and
	// owns execution.
	owner, err := c.repo.CreateProcessing(ctx, record)
	if err != nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "create_processing_error")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
			"operation": "create_processing",
		})
		return nil, ErrIdempotencyStoreUnavail.WithCause(err)
	}
	if owner {
		recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "new_claim"})
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "none->processing", false, map[string]string{
			"claim_mode": "new",
		})
	}
	if !owner {
		// Lost the insert race or the record pre-exists: inspect it to decide
		// between replay, conflict, or reclaim.
		existing, getErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash)
		if getErr != nil {
			RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_error")
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
				"operation": "get_existing",
			})
			return nil, ErrIdempotencyStoreUnavail.WithCause(getErr)
		}
		if existing == nil {
			// Insert said "exists" but lookup found nothing: inconsistent
			// store state; fail closed.
			RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing")
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
				"operation": "missing_existing",
			})
			return nil, ErrIdempotencyStoreUnavail
		}
		if existing.RequestFingerprint != fingerprint {
			// Same key, different request payload/identity: never replay.
			recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"})
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil)
			return nil, ErrIdempotencyKeyConflict
		}
		reclaimedByExpired := false
		if !existing.ExpiresAt.After(now) {
			// Record has outlived its TTL: try to take it over regardless of
			// its terminal status.
			taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, existing.Status, now, lockedUntil, expiresAt)
			if reclaimErr != nil {
				RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_expired_error")
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->store_unavailable", false, map[string]string{
					"operation": "try_reclaim_expired",
				})
				return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
			}
			if taken {
				reclaimedByExpired = true
				recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "expired_reclaim"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->processing", false, map[string]string{
					"claim_mode": "expired_reclaim",
				})
				record.ID = existing.ID
			} else {
				// Someone else reclaimed first: re-read and fall through to
				// the status switch with the fresh record.
				latest, latestErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash)
				if latestErr != nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_after_expired_reclaim_error")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
						"operation": "get_existing_after_expired_reclaim",
					})
					return nil, ErrIdempotencyStoreUnavail.WithCause(latestErr)
				}
				if latest == nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing_after_expired_reclaim")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{
						"operation": "missing_existing_after_expired_reclaim",
					})
					return nil, ErrIdempotencyStoreUnavail
				}
				if latest.RequestFingerprint != fingerprint {
					recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"})
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil)
					return nil, ErrIdempotencyKeyConflict
				}
				existing = latest
			}
		}
		if !reclaimedByExpired {
			switch existing.Status {
			case IdempotencyStatusSucceeded:
				// Prior success: decode and replay the stored response.
				data, parseErr := c.decodeStoredResponse(existing.ResponseBody)
				if parseErr != nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "decode_stored_response_error")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->store_unavailable", false, map[string]string{
						"operation": "decode_stored_response",
					})
					return nil, ErrIdempotencyStoreUnavail.WithCause(parseErr)
				}
				recordIdempotencyReplay(opts.Route, opts.Scope, nil)
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->replayed", true, nil)
				return &IdempotencyExecuteResult{Data: data, Replayed: true}, nil
			case IdempotencyStatusProcessing:
				// Another caller is mid-flight: tell the client to retry
				// after the processing lock elapses.
				recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "in_progress"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->conflict", false, nil)
				return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now)
			case IdempotencyStatusFailedRetryable:
				if existing.LockedUntil != nil && existing.LockedUntil.After(now) {
					// Still inside the failure backoff window.
					recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "retry_backoff"})
					recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil)
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->retry_backoff_conflict", false, nil)
					return nil, c.conflictWithRetryAfter(ErrIdempotencyRetryBackoff, existing.LockedUntil, now)
				}
				// Backoff elapsed: race to reclaim the failed record.
				taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, IdempotencyStatusFailedRetryable, now, lockedUntil, expiresAt)
				if reclaimErr != nil {
					RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_error")
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->store_unavailable", false, map[string]string{
						"operation": "try_reclaim",
					})
					return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
				}
				if !taken {
					recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "reclaim_race"})
					logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->conflict", false, map[string]string{
						"conflict": "reclaim_race",
					})
					return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now)
				}
				recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "reclaim"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->processing", false, map[string]string{
					"claim_mode": "reclaim",
				})
				record.ID = existing.ID
			default:
				recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "unexpected_status"})
				logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->conflict", false, map[string]string{
					"status": existing.Status,
				})
				return nil, ErrIdempotencyKeyConflict
			}
		}
	}
	// Sanity check: by now this call must own a concrete record to mark.
	if record.ID == 0 {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "record_id_missing")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
			"operation": "record_id_missing",
		})
		return nil, ErrIdempotencyStoreUnavail
	}
	execStart := time.Now()
	defer func() {
		recordIdempotencyProcessingDuration(opts.Route, opts.Scope, time.Since(execStart), nil)
	}()
	data, execErr := execute(ctx)
	if execErr != nil {
		// Mark the record retryable-after-backoff; the original execution
		// error is returned even if the mark itself fails.
		backoffUntil := time.Now().Add(c.cfg.FailedRetryBackoff)
		reason := infraerrors.Reason(execErr)
		if reason == "" {
			reason = "EXECUTION_FAILED"
		}
		recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil)
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->failed_retryable", false, map[string]string{
			"reason": reason,
		})
		if markErr := c.repo.MarkFailedRetryable(ctx, record.ID, reason, backoffUntil, expiresAt); markErr != nil {
			RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_failed_retryable_error")
			logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
				"operation": "mark_failed_retryable",
			})
		}
		return nil, execErr
	}
	storedBody, marshalErr := c.marshalStoredResponse(data)
	if marshalErr != nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "marshal_response_error")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
			"operation": "marshal_response",
		})
		return nil, ErrIdempotencyStoreUnavail.WithCause(marshalErr)
	}
	if markErr := c.repo.MarkSucceeded(ctx, record.ID, 200, storedBody, expiresAt); markErr != nil {
		RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_succeeded_error")
		logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{
			"operation": "mark_succeeded",
		})
		return nil, ErrIdempotencyStoreUnavail.WithCause(markErr)
	}
	logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->succeeded", false, nil)
	return &IdempotencyExecuteResult{Data: data}, nil
}
// conflictWithRetryAfter decorates base with a retry_after metadata hint in
// whole seconds (minimum 1) derived from lockedUntil. Without a lock deadline
// the base error is returned unchanged.
func (c *IdempotencyCoordinator) conflictWithRetryAfter(base *infraerrors.ApplicationError, lockedUntil *time.Time, now time.Time) error {
	if lockedUntil == nil {
		return base
	}
	seconds := int(lockedUntil.Sub(now).Seconds())
	if seconds < 1 {
		seconds = 1
	}
	return base.WithMetadata(map[string]string{"retry_after": strconv.Itoa(seconds)})
}
// marshalStoredResponse serializes data to JSON, applies log redaction, and
// caps the stored length at cfg.MaxStoredResponseLen.
//
// NOTE(review): truncation appends "...(truncated)" to a cut JSON string,
// which is no longer valid JSON — decodeStoredResponse will then fail on
// replay and Execute surfaces ErrIdempotencyStoreUnavail for that key until it
// expires. Confirm whether oversized responses should instead be stored as a
// valid JSON placeholder (and note the byte-slice cut may also split a UTF-8
// rune).
func (c *IdempotencyCoordinator) marshalStoredResponse(data any) (string, error) {
	raw, err := json.Marshal(data)
	if err != nil {
		return "", err
	}
	redacted := logredact.RedactText(string(raw))
	if c.cfg.MaxStoredResponseLen > 0 && len(redacted) > c.cfg.MaxStoredResponseLen {
		redacted = redacted[:c.cfg.MaxStoredResponseLen] + "...(truncated)"
	}
	return redacted, nil
}
// decodeStoredResponse parses a previously stored response body back into a
// generic value for replay. A nil or blank body decodes to an empty object.
func (c *IdempotencyCoordinator) decodeStoredResponse(stored *string) (any, error) {
	if stored == nil || strings.TrimSpace(*stored) == "" {
		return map[string]any{}, nil
	}
	var decoded any
	if unmarshalErr := json.Unmarshal([]byte(*stored), &decoded); unmarshalErr != nil {
		return nil, fmt.Errorf("decode stored response: %w", unmarshalErr)
	}
	return decoded, nil
}

View File

@@ -0,0 +1,91 @@
package service
import (
"context"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// IdempotencyCleanupService periodically deletes expired idempotency records
// so the backing table does not grow without bound.
type IdempotencyCleanupService struct {
	repo      IdempotencyRepository // store to purge; nil makes Start a no-op
	interval  time.Duration         // delay between cleanup sweeps
	batch     int                   // max rows removed per sweep
	startOnce sync.Once             // guards single background-loop start
	stopOnce  sync.Once             // guards single stopCh close
	stopCh    chan struct{}         // closed by Stop to terminate runLoop
}
// NewIdempotencyCleanupService builds a cleanup service, overriding the
// built-in defaults (60s interval, batch of 500) with any positive values
// supplied in cfg.
func NewIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService {
	svc := &IdempotencyCleanupService{
		repo:     repo,
		interval: 60 * time.Second,
		batch:    500,
		stopCh:   make(chan struct{}),
	}
	if cfg == nil {
		return svc
	}
	if secs := cfg.Idempotency.CleanupIntervalSeconds; secs > 0 {
		svc.interval = time.Duration(secs) * time.Second
	}
	if size := cfg.Idempotency.CleanupBatchSize; size > 0 {
		svc.batch = size
	}
	return svc
}
// Start launches the background cleanup loop exactly once. It is a no-op on a
// nil receiver or when no repository is configured.
func (s *IdempotencyCleanupService) Start() {
	if s == nil || s.repo == nil {
		return
	}
	s.startOnce.Do(func() {
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] started interval=%s batch=%d", s.interval, s.batch)
		go s.runLoop()
	})
}

// Stop signals the cleanup loop to exit. Safe to call multiple times and on a
// nil receiver.
func (s *IdempotencyCleanupService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		close(s.stopCh)
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] stopped")
	})
}
// runLoop drives periodic cleanup sweeps until stopCh is closed.
func (s *IdempotencyCleanupService) runLoop() {
	ticker := time.NewTicker(s.interval)
	defer ticker.Stop()
	// Run one sweep immediately so a restart does not leave a backlog waiting
	// for the first tick.
	s.cleanupOnce()
	for {
		select {
		case <-ticker.C:
			s.cleanupOnce()
		case <-s.stopCh:
			return
		}
	}
}
// cleanupOnce performs a single bounded deletion sweep with a 10-second
// timeout, logging failures and non-zero deletion counts.
func (s *IdempotencyCleanupService) cleanupOnce() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	deleted, err := s.repo.DeleteExpired(ctx, time.Now(), s.batch)
	switch {
	case err != nil:
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleanup failed err=%v", err)
	case deleted > 0:
		logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleaned expired records count=%d", deleted)
	}
}

View File

@@ -0,0 +1,69 @@
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// idempotencyCleanupRepoStub is a minimal IdempotencyRepository stub that only
// records DeleteExpired calls; every other method is a no-op.
type idempotencyCleanupRepoStub struct {
	deleteCalls int   // number of DeleteExpired invocations
	lastLimit   int   // limit argument of the most recent call
	deleteErr   error // when set, DeleteExpired returns this error
}
// CreateProcessing is a no-op that never claims ownership.
func (r *idempotencyCleanupRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	return false, nil
}

// GetByScopeAndKeyHash is a no-op returning no record.
func (r *idempotencyCleanupRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	return nil, nil
}

// TryReclaim is a no-op that never grants the reclaim.
func (r *idempotencyCleanupRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, nil
}

// ExtendProcessingLock is a no-op that never extends the lock.
func (r *idempotencyCleanupRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, nil
}

// MarkSucceeded is a no-op.
func (r *idempotencyCleanupRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return nil
}

// MarkFailedRetryable is a no-op.
func (r *idempotencyCleanupRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return nil
}

// DeleteExpired records the call and its limit; it returns deleteErr when
// configured, otherwise reports one deleted record.
func (r *idempotencyCleanupRepoStub) DeleteExpired(_ context.Context, _ time.Time, limit int) (int64, error) {
	r.deleteCalls++
	r.lastLimit = limit
	if r.deleteErr != nil {
		return 0, r.deleteErr
	}
	return 1, nil
}
// TestNewIdempotencyCleanupService_UsesConfig verifies that positive config
// values override the built-in interval and batch-size defaults.
func TestNewIdempotencyCleanupService_UsesConfig(t *testing.T) {
	repo := &idempotencyCleanupRepoStub{}
	cfg := &config.Config{
		Idempotency: config.IdempotencyConfig{
			CleanupIntervalSeconds: 7,
			CleanupBatchSize:       321,
		},
	}
	svc := NewIdempotencyCleanupService(repo, cfg)
	require.Equal(t, 7*time.Second, svc.interval)
	require.Equal(t, 321, svc.batch)
}
// TestIdempotencyCleanupService_CleanupOnce verifies that a single sweep calls
// DeleteExpired exactly once with the configured batch size as the limit.
func TestIdempotencyCleanupService_CleanupOnce(t *testing.T) {
	repo := &idempotencyCleanupRepoStub{}
	svc := NewIdempotencyCleanupService(repo, &config.Config{
		Idempotency: config.IdempotencyConfig{
			CleanupBatchSize: 99,
		},
	})
	svc.cleanupOnce()
	require.Equal(t, 1, repo.deleteCalls)
	require.Equal(t, 99, repo.lastLimit)
}

View File

@@ -0,0 +1,171 @@
package service
import (
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// IdempotencyMetricsSnapshot is a point-in-time view of the core idempotency
// counters (process-local, cumulative since start).
type IdempotencyMetricsSnapshot struct {
	ClaimTotal                uint64  `json:"claim_total"`
	ReplayTotal               uint64  `json:"replay_total"`
	ConflictTotal             uint64  `json:"conflict_total"`
	RetryBackoffTotal         uint64  `json:"retry_backoff_total"`
	ProcessingDurationCount   uint64  `json:"processing_duration_count"`
	ProcessingDurationTotalMs float64 `json:"processing_duration_total_ms"`
	StoreUnavailableTotal     uint64  `json:"store_unavailable_total"`
}

// idempotencyMetrics holds the live lock-free counters backing the snapshot.
type idempotencyMetrics struct {
	claimTotal               atomic.Uint64
	replayTotal              atomic.Uint64
	conflictTotal            atomic.Uint64
	retryBackoffTotal        atomic.Uint64
	processingDurationCount  atomic.Uint64
	processingDurationMicros atomic.Uint64 // microseconds; converted to ms in the snapshot
	storeUnavailableTotal    atomic.Uint64
}

// defaultIdempotencyMetrics is the process-wide counter set.
var defaultIdempotencyMetrics idempotencyMetrics
// GetIdempotencyMetricsSnapshot returns the current idempotency metrics
// snapshot. Fields are loaded individually, so the snapshot is not a single
// atomic cut across all counters.
func GetIdempotencyMetricsSnapshot() IdempotencyMetricsSnapshot {
	totalMicros := defaultIdempotencyMetrics.processingDurationMicros.Load()
	return IdempotencyMetricsSnapshot{
		ClaimTotal:                defaultIdempotencyMetrics.claimTotal.Load(),
		ReplayTotal:               defaultIdempotencyMetrics.replayTotal.Load(),
		ConflictTotal:             defaultIdempotencyMetrics.conflictTotal.Load(),
		RetryBackoffTotal:         defaultIdempotencyMetrics.retryBackoffTotal.Load(),
		ProcessingDurationCount:   defaultIdempotencyMetrics.processingDurationCount.Load(),
		ProcessingDurationTotalMs: float64(totalMicros) / 1000.0,
		StoreUnavailableTotal:     defaultIdempotencyMetrics.storeUnavailableTotal.Load(),
	}
}
// recordIdempotencyClaim counts a successful processing claim (new or
// reclaimed) and emits the metric log line.
func recordIdempotencyClaim(endpoint, scope string, attrs map[string]string) {
	defaultIdempotencyMetrics.claimTotal.Add(1)
	logIdempotencyMetric("idempotency_claim_total", endpoint, scope, "1", attrs)
}

// recordIdempotencyReplay counts a stored-response replay.
func recordIdempotencyReplay(endpoint, scope string, attrs map[string]string) {
	defaultIdempotencyMetrics.replayTotal.Add(1)
	logIdempotencyMetric("idempotency_replay_total", endpoint, scope, "1", attrs)
}

// recordIdempotencyConflict counts a rejected request (in-progress, key
// reuse, reclaim race, ...).
func recordIdempotencyConflict(endpoint, scope string, attrs map[string]string) {
	defaultIdempotencyMetrics.conflictTotal.Add(1)
	logIdempotencyMetric("idempotency_conflict_total", endpoint, scope, "1", attrs)
}

// recordIdempotencyRetryBackoff counts entries into the failure-backoff state.
func recordIdempotencyRetryBackoff(endpoint, scope string, attrs map[string]string) {
	defaultIdempotencyMetrics.retryBackoffTotal.Add(1)
	logIdempotencyMetric("idempotency_retry_backoff_total", endpoint, scope, "1", attrs)
}
// recordIdempotencyProcessingDuration accumulates one execution duration
// (count + total microseconds) and logs it in milliseconds. Negative
// durations (clock adjustments) are clamped to zero.
func recordIdempotencyProcessingDuration(endpoint, scope string, duration time.Duration, attrs map[string]string) {
	if duration < 0 {
		duration = 0
	}
	defaultIdempotencyMetrics.processingDurationCount.Add(1)
	defaultIdempotencyMetrics.processingDurationMicros.Add(uint64(duration.Microseconds()))
	logIdempotencyMetric("idempotency_processing_duration_ms", endpoint, scope, strconv.FormatFloat(duration.Seconds()*1000, 'f', 3, 64), attrs)
}
// RecordIdempotencyStoreUnavailable records an idempotency-store-unavailable
// event; strategy labels the failing operation for degraded-path observability.
func RecordIdempotencyStoreUnavailable(endpoint, scope, strategy string) {
	defaultIdempotencyMetrics.storeUnavailableTotal.Add(1)
	attrs := map[string]string{}
	if strategy != "" {
		attrs["strategy"] = strategy
	}
	logIdempotencyMetric("idempotency_store_unavailable_total", endpoint, scope, "1", attrs)
}
// logIdempotencyAudit emits one key=value audit line describing a record
// state transition. All fields pass through safeAuditField so the line stays
// machine-parseable; extra attrs are appended in sorted key order.
func logIdempotencyAudit(endpoint, scope, keyHash, stateTransition string, replayed bool, attrs map[string]string) {
	var b strings.Builder
	builderWriteString(&b, "[IdempotencyAudit]")
	builderWriteString(&b, " endpoint=")
	builderWriteString(&b, safeAuditField(endpoint))
	builderWriteString(&b, " scope=")
	builderWriteString(&b, safeAuditField(scope))
	builderWriteString(&b, " key_hash=")
	builderWriteString(&b, safeAuditField(keyHash))
	builderWriteString(&b, " state_transition=")
	builderWriteString(&b, safeAuditField(stateTransition))
	builderWriteString(&b, " replayed=")
	builderWriteString(&b, strconv.FormatBool(replayed))
	if len(attrs) > 0 {
		appendSortedAttrs(&b, attrs)
	}
	logger.LegacyPrintf("service.idempotency", "%s", b.String())
}
// logIdempotencyMetric emits one key=value metric line (name, endpoint,
// scope, value plus sorted extra attrs) to the legacy logger.
func logIdempotencyMetric(name, endpoint, scope, value string, attrs map[string]string) {
	var b strings.Builder
	builderWriteString(&b, "[IdempotencyMetric]")
	builderWriteString(&b, " name=")
	builderWriteString(&b, safeAuditField(name))
	builderWriteString(&b, " endpoint=")
	builderWriteString(&b, safeAuditField(endpoint))
	builderWriteString(&b, " scope=")
	builderWriteString(&b, safeAuditField(scope))
	builderWriteString(&b, " value=")
	builderWriteString(&b, safeAuditField(value))
	if len(attrs) > 0 {
		appendSortedAttrs(&b, attrs)
	}
	logger.LegacyPrintf("service.idempotency", "%s", b.String())
}
// appendSortedAttrs writes each attrs entry as " key=value" in ascending key
// order (deterministic output despite random map iteration), sanitizing the
// values via safeAuditField.
func appendSortedAttrs(builder *strings.Builder, attrs map[string]string) {
	if len(attrs) == 0 {
		return
	}
	sortedKeys := make([]string, 0, len(attrs))
	for key := range attrs {
		sortedKeys = append(sortedKeys, key)
	}
	sort.Strings(sortedKeys)
	for _, key := range sortedKeys {
		builderWriteByte(builder, ' ')
		builderWriteString(builder, key)
		builderWriteByte(builder, '=')
		builderWriteString(builder, safeAuditField(attrs[key]))
	}
}
// safeAuditField sanitizes v for key=value log output: blank values become
// "-", and newline, carriage-return, tab, and space characters are replaced
// with '_' so downstream parsers see a single unambiguous token.
func safeAuditField(v string) string {
	trimmed := strings.TrimSpace(v)
	if trimmed == "" {
		return "-"
	}
	return strings.Map(func(r rune) rune {
		switch r {
		case '\n', '\r', '\t', ' ':
			return '_'
		}
		return r
	}, trimmed)
}
// resetIdempotencyMetricsForTest zeroes every counter; test-only helper to
// isolate metric assertions between test cases.
func resetIdempotencyMetricsForTest() {
	defaultIdempotencyMetrics.claimTotal.Store(0)
	defaultIdempotencyMetrics.replayTotal.Store(0)
	defaultIdempotencyMetrics.conflictTotal.Store(0)
	defaultIdempotencyMetrics.retryBackoffTotal.Store(0)
	defaultIdempotencyMetrics.processingDurationCount.Store(0)
	defaultIdempotencyMetrics.processingDurationMicros.Store(0)
	defaultIdempotencyMetrics.storeUnavailableTotal.Store(0)
}
// builderWriteString appends value to builder; strings.Builder writes never
// return a non-nil error, so it is deliberately discarded.
func builderWriteString(builder *strings.Builder, value string) {
	_, _ = builder.WriteString(value)
}

// builderWriteByte appends a single byte to builder, discarding the
// always-nil error for the same reason.
func builderWriteByte(builder *strings.Builder, value byte) {
	_ = builder.WriteByte(value)
}

View File

@@ -0,0 +1,805 @@
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
// inMemoryIdempotencyRepo is a mutex-guarded map-backed IdempotencyRepository
// used by the coordinator tests.
type inMemoryIdempotencyRepo struct {
	mu     sync.Mutex
	nextID int64                         // next record ID to assign
	data   map[string]*IdempotencyRecord // keyed by scope|keyHash
}

// newInMemoryIdempotencyRepo returns an empty repo with IDs starting at 1.
func newInMemoryIdempotencyRepo() *inMemoryIdempotencyRepo {
	return &inMemoryIdempotencyRepo{
		nextID: 1,
		data:   make(map[string]*IdempotencyRecord),
	}
}

// key builds the composite map key for a (scope, key-hash) pair.
func (r *inMemoryIdempotencyRepo) key(scope, hash string) string {
	return scope + "|" + hash
}
// cloneRecord returns a deep copy of in so callers cannot mutate repo-held
// state through shared pointer fields; a nil input yields nil.
func cloneRecord(in *IdempotencyRecord) *IdempotencyRecord {
	if in == nil {
		return nil
	}
	out := *in
	if in.ResponseStatus != nil {
		status := *in.ResponseStatus
		out.ResponseStatus = &status
	}
	if in.ResponseBody != nil {
		body := *in.ResponseBody
		out.ResponseBody = &body
	}
	if in.ErrorReason != nil {
		reason := *in.ErrorReason
		out.ErrorReason = &reason
	}
	if in.LockedUntil != nil {
		until := *in.LockedUntil
		out.LockedUntil = &until
	}
	return &out
}
// CreateProcessing inserts record if absent, assigning ID and timestamps back
// onto the caller's struct; it returns false (no error) when the key already
// exists, mirroring first-writer-wins semantics.
func (r *inMemoryIdempotencyRepo) CreateProcessing(_ context.Context, record *IdempotencyRecord) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	k := r.key(record.Scope, record.IdempotencyKeyHash)
	if _, ok := r.data[k]; ok {
		return false, nil
	}
	// Store a clone so later caller mutations cannot corrupt repo state.
	rec := cloneRecord(record)
	rec.ID = r.nextID
	rec.CreatedAt = time.Now()
	rec.UpdatedAt = rec.CreatedAt
	r.nextID++
	r.data[k] = rec
	// Reflect the assigned identity back to the caller, as a DB would.
	record.ID = rec.ID
	record.CreatedAt = rec.CreatedAt
	record.UpdatedAt = rec.UpdatedAt
	return true, nil
}
// GetByScopeAndKeyHash returns a deep copy of the stored record, or nil when
// absent (cloneRecord maps a missing entry to nil).
func (r *inMemoryIdempotencyRepo) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*IdempotencyRecord, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	return cloneRecord(r.data[r.key(scope, keyHash)]), nil
}
// TryReclaim atomically moves record id from fromStatus back to processing
// with fresh lock/expiry deadlines. It refuses (false, nil) when the status
// no longer matches or the record is still lock-held past now.
func (r *inMemoryIdempotencyRepo) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		if rec.Status != fromStatus {
			return false, nil
		}
		if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
			return false, nil
		}
		rec.Status = IdempotencyStatusProcessing
		rec.LockedUntil = &newLockedUntil
		rec.ExpiresAt = newExpiresAt
		rec.ErrorReason = nil
		rec.UpdatedAt = time.Now()
		return true, nil
	}
	return false, nil
}
// ExtendProcessingLock pushes out the lock and expiry deadlines of record id,
// but only while it is still processing with a matching fingerprint (guards
// against extending someone else's reclaimed record).
func (r *inMemoryIdempotencyRepo) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		if rec.Status != IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
			return false, nil
		}
		rec.LockedUntil = &newLockedUntil
		rec.ExpiresAt = newExpiresAt
		rec.UpdatedAt = time.Now()
		return true, nil
	}
	return false, nil
}
// MarkSucceeded finalizes record id as SUCCEEDED: it releases the lock,
// clears any prior error, and stores the response so later requests can
// replay it. It errors when the id is unknown.
func (r *inMemoryIdempotencyRepo) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		rec.Status = IdempotencyStatusSucceeded
		rec.LockedUntil = nil
		rec.ExpiresAt = expiresAt
		rec.UpdatedAt = time.Now()
		rec.ErrorReason = nil
		rec.ResponseStatus = &responseStatus
		rec.ResponseBody = &responseBody
		return nil
	}
	return errors.New("record not found")
}
// MarkFailedRetryable records a retryable failure for record id, holding the
// lock until lockedUntil so retries back off instead of racing. It errors when
// the id is unknown.
func (r *inMemoryIdempotencyRepo) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		rec.Status = IdempotencyStatusFailedRetryable
		rec.LockedUntil = &lockedUntil
		rec.ExpiresAt = expiresAt
		rec.UpdatedAt = time.Now()
		rec.ErrorReason = &errorReason
		return nil
	}
	return errors.New("record not found")
}
// DeleteExpired removes every record whose ExpiresAt is at or before now and
// returns how many were dropped. The batch-size limit is ignored by this
// in-memory implementation.
func (r *inMemoryIdempotencyRepo) DeleteExpired(_ context.Context, now time.Time, _ int) (int64, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	deleted := int64(0)
	for slot, rec := range r.data {
		if rec.ExpiresAt.After(now) {
			continue
		}
		delete(r.data, slot)
		deleted++
	}
	return deleted, nil
}
// TestIdempotencyCoordinator_RequireKey verifies that once observe-only mode
// is disabled, a scope with RequireKey rejects requests carrying no
// idempotency key.
func TestIdempotencyCoordinator_RequireKey(t *testing.T) {
	resetIdempotencyMetricsForTest()
	repo := newInMemoryIdempotencyRepo()
	cfg := DefaultIdempotencyConfig()
	cfg.ObserveOnly = false
	coordinator := NewIdempotencyCoordinator(repo, cfg)
	_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:      "test.scope",
		Method:     "POST",
		Route:      "/test",
		ActorScope: "admin:1",
		RequireKey: true,
		Payload:    map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	// require.Equal takes (expected, actual): the sentinel's code is the
	// expectation, the returned error's code is what we observed.
	require.Equal(t, infraerrors.Code(ErrIdempotencyKeyRequired), infraerrors.Code(err))
}
// TestIdempotencyCoordinator_ReplaySucceededResult verifies that a second
// request with the same key replays the stored response instead of running
// the business executor again.
func TestIdempotencyCoordinator_ReplaySucceededResult(t *testing.T) {
	resetIdempotencyMetricsForTest()
	repo := newInMemoryIdempotencyRepo()
	cfg := DefaultIdempotencyConfig()
	coordinator := NewIdempotencyCoordinator(repo, cfg)
	execCount := 0
	exec := func(ctx context.Context) (any, error) {
		execCount++
		return map[string]any{"count": execCount}, nil
	}
	opts := IdempotencyExecuteOptions{
		Scope:          "test.scope",
		Method:         "POST",
		Route:          "/test",
		ActorScope:     "user:1",
		RequireKey:     true,
		IdempotencyKey: "case-1",
		Payload:        map[string]any{"a": 1},
	}
	first, err := coordinator.Execute(context.Background(), opts, exec)
	require.NoError(t, err)
	require.False(t, first.Replayed)
	second, err := coordinator.Execute(context.Background(), opts, exec)
	require.NoError(t, err)
	require.True(t, second.Replayed)
	require.Equal(t, 1, execCount, "second request should replay without executing business logic")
	// Exactly one claim (first call) and one replay (second call) expected.
	metrics := GetIdempotencyMetricsSnapshot()
	require.Equal(t, uint64(1), metrics.ClaimTotal)
	require.Equal(t, uint64(1), metrics.ReplayTotal)
}
// TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord verifies that a
// SUCCEEDED record whose TTL has lapsed is reclaimed (business logic runs
// again) rather than replayed, and that the new result becomes the one
// replayed afterwards.
func TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord(t *testing.T) {
	resetIdempotencyMetricsForTest()
	repo := newInMemoryIdempotencyRepo()
	coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
	opts := IdempotencyExecuteOptions{
		Scope:          "test.scope.expired",
		Method:         "POST",
		Route:          "/test/expired",
		ActorScope:     "user:99",
		RequireKey:     true,
		IdempotencyKey: "expired-case",
		Payload:        map[string]any{"k": "v"},
	}
	execCount := 0
	exec := func(ctx context.Context) (any, error) {
		execCount++
		return map[string]any{"count": execCount}, nil
	}
	first, err := coordinator.Execute(context.Background(), opts, exec)
	require.NoError(t, err)
	require.NotNil(t, first)
	require.False(t, first.Replayed)
	require.Equal(t, 1, execCount)
	// Force the stored record past its TTL by editing the repo directly.
	keyHash := HashIdempotencyKey(opts.IdempotencyKey)
	repo.mu.Lock()
	existing := repo.data[repo.key(opts.Scope, keyHash)]
	require.NotNil(t, existing)
	existing.ExpiresAt = time.Now().Add(-time.Second)
	repo.mu.Unlock()
	second, err := coordinator.Execute(context.Background(), opts, exec)
	require.NoError(t, err)
	require.NotNil(t, second)
	require.False(t, second.Replayed, "expired record should be reclaimed and execute business logic again")
	require.Equal(t, 2, execCount)
	third, err := coordinator.Execute(context.Background(), opts, exec)
	require.NoError(t, err)
	require.NotNil(t, third)
	require.True(t, third.Replayed)
	// Stored responses round-trip through JSON, so numbers decode as float64.
	payload, ok := third.Data.(map[string]any)
	require.True(t, ok)
	require.Equal(t, float64(2), payload["count"])
	metrics := GetIdempotencyMetricsSnapshot()
	require.GreaterOrEqual(t, metrics.ClaimTotal, uint64(2))
	require.GreaterOrEqual(t, metrics.ReplayTotal, uint64(1))
}
// TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict verifies that
// reusing an idempotency key with a different payload (different request
// fingerprint) is rejected as a key conflict and counted in ConflictTotal.
func TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict(t *testing.T) {
	resetIdempotencyMetricsForTest()
	repo := newInMemoryIdempotencyRepo()
	cfg := DefaultIdempotencyConfig()
	coordinator := NewIdempotencyCoordinator(repo, cfg)
	_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "test.scope",
		Method:         "POST",
		Route:          "/test",
		ActorScope:     "user:1",
		RequireKey:     true,
		IdempotencyKey: "case-2",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.NoError(t, err)
	// Same key, different payload: the fingerprint mismatch must conflict.
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "test.scope",
		Method:         "POST",
		Route:          "/test",
		ActorScope:     "user:1",
		RequireKey:     true,
		IdempotencyKey: "case-2",
		Payload:        map[string]any{"a": 2},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	// require.Equal takes (expected, actual); the sentinel code comes first.
	require.Equal(t, infraerrors.Code(ErrIdempotencyKeyConflict), infraerrors.Code(err))
	metrics := GetIdempotencyMetricsSnapshot()
	require.Equal(t, uint64(1), metrics.ConflictTotal)
}
// TestIdempotencyCoordinator_BackoffAfterRetryableFailure verifies that after
// a retryable business failure, an immediate retry with the same key is held
// back with ErrIdempotencyRetryBackoff and a positive retry-after hint.
func TestIdempotencyCoordinator_BackoffAfterRetryableFailure(t *testing.T) {
	resetIdempotencyMetricsForTest()
	repo := newInMemoryIdempotencyRepo()
	cfg := DefaultIdempotencyConfig()
	cfg.FailedRetryBackoff = 2 * time.Second
	coordinator := NewIdempotencyCoordinator(repo, cfg)
	opts := IdempotencyExecuteOptions{
		Scope:          "test.scope",
		Method:         "POST",
		Route:          "/test",
		ActorScope:     "user:1",
		RequireKey:     true,
		IdempotencyKey: "case-3",
		Payload:        map[string]any{"a": 1},
	}
	// First attempt fails with a retryable (5xx) error.
	_, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
		return nil, infraerrors.InternalServer("UPSTREAM_ERROR", "upstream error")
	})
	require.Error(t, err)
	// Second attempt arrives inside the backoff window and must be rejected.
	_, err = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	// require.Equal takes (expected, actual); the sentinel code comes first.
	require.Equal(t, infraerrors.Code(ErrIdempotencyRetryBackoff), infraerrors.Code(err))
	require.Greater(t, RetryAfterSecondsFromError(err), 0)
	metrics := GetIdempotencyMetricsSnapshot()
	require.GreaterOrEqual(t, metrics.RetryBackoffTotal, uint64(2))
	require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1))
	require.GreaterOrEqual(t, metrics.ProcessingDurationCount, uint64(1))
}
// TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect verifies that
// under concurrent requests with the same key, the business side effect runs
// exactly once: one goroutine wins the claim, the rest are rejected, and a
// later request replays the stored result.
func TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect(t *testing.T) {
	resetIdempotencyMetricsForTest()
	repo := newInMemoryIdempotencyRepo()
	cfg := DefaultIdempotencyConfig()
	cfg.ProcessingTimeout = 2 * time.Second
	coordinator := NewIdempotencyCoordinator(repo, cfg)
	opts := IdempotencyExecuteOptions{
		Scope:          "test.scope.concurrent",
		Method:         "POST",
		Route:          "/test/concurrent",
		ActorScope:     "user:7",
		RequireKey:     true,
		IdempotencyKey: "concurrent-case",
		Payload:        map[string]any{"v": 1},
	}
	var execCount int32
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Errors are ignored on purpose: losers of the claim race fail,
			// and only the side-effect count matters here.
			_, _ = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
				atomic.AddInt32(&execCount, 1)
				// Sleep keeps the winner in PROCESSING while the others race.
				time.Sleep(80 * time.Millisecond)
				return map[string]any{"ok": true}, nil
			})
		}()
	}
	wg.Wait()
	replayed, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) {
		atomic.AddInt32(&execCount, 1)
		return map[string]any{"ok": true}, nil
	})
	require.NoError(t, err)
	require.True(t, replayed.Replayed)
	require.Equal(t, int32(1), atomic.LoadInt32(&execCount), "concurrent same-key requests should execute business side-effect once")
	metrics := GetIdempotencyMetricsSnapshot()
	require.Equal(t, uint64(1), metrics.ClaimTotal)
	require.Equal(t, uint64(1), metrics.ReplayTotal)
	require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1))
}
// failingIdempotencyRepo simulates a fully unavailable idempotency store:
// every repository method returns a "store unavailable" error. Used to assert
// the coordinator's degradation behavior and metrics.
type failingIdempotencyRepo struct{}
func (failingIdempotencyRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	return false, errors.New("store unavailable")
}
func (failingIdempotencyRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	return nil, errors.New("store unavailable")
}
func (failingIdempotencyRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, errors.New("store unavailable")
}
func (failingIdempotencyRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, errors.New("store unavailable")
}
func (failingIdempotencyRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return errors.New("store unavailable")
}
func (failingIdempotencyRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return errors.New("store unavailable")
}
func (failingIdempotencyRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) {
	return 0, errors.New("store unavailable")
}
// TestIdempotencyCoordinator_StoreUnavailableMetrics verifies that a failing
// store surfaces ErrIdempotencyStoreUnavail and bumps the store-unavailable
// counter.
func TestIdempotencyCoordinator_StoreUnavailableMetrics(t *testing.T) {
	resetIdempotencyMetricsForTest()
	coordinator := NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig())
	_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "test.scope.unavailable",
		Method:         "POST",
		Route:          "/test/unavailable",
		ActorScope:     "admin:1",
		RequireKey:     true,
		IdempotencyKey: "case-unavailable",
		Payload:        map[string]any{"v": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	require.GreaterOrEqual(t, GetIdempotencyMetricsSnapshot().StoreUnavailableTotal, uint64(1))
}
// TestDefaultIdempotencyCoordinatorAndTTLs exercises the package-level default
// coordinator registration: with no coordinator set, the TTL getters fall back
// to DefaultIdempotencyConfig; with one set, they reflect its config.
func TestDefaultIdempotencyCoordinatorAndTTLs(t *testing.T) {
	SetDefaultIdempotencyCoordinator(nil)
	require.Nil(t, DefaultIdempotencyCoordinator())
	require.Equal(t, DefaultIdempotencyConfig().DefaultTTL, DefaultWriteIdempotencyTTL())
	require.Equal(t, DefaultIdempotencyConfig().SystemOperationTTL, DefaultSystemOperationIdempotencyTTL())
	coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{
		DefaultTTL:         2 * time.Hour,
		SystemOperationTTL: 15 * time.Minute,
		ProcessingTimeout:  10 * time.Second,
		FailedRetryBackoff: 3 * time.Second,
		ObserveOnly:        false,
	})
	SetDefaultIdempotencyCoordinator(coordinator)
	// Always restore global state so other tests see a clean default.
	t.Cleanup(func() {
		SetDefaultIdempotencyCoordinator(nil)
	})
	require.Same(t, coordinator, DefaultIdempotencyCoordinator())
	require.Equal(t, 2*time.Hour, DefaultWriteIdempotencyTTL())
	require.Equal(t, 15*time.Minute, DefaultSystemOperationIdempotencyTTL())
}
// TestNormalizeIdempotencyKeyAndFingerprint checks key normalization (trim,
// empty pass-through, over-length and bad-character rejection) and that
// fingerprints canonicalize empty method/route/actor to the same value as
// their defaults, while unmarshalable payloads yield
// ErrIdempotencyInvalidPayload.
func TestNormalizeIdempotencyKeyAndFingerprint(t *testing.T) {
	key, err := NormalizeIdempotencyKey("  abc-123  ")
	require.NoError(t, err)
	require.Equal(t, "abc-123", key)
	key, err = NormalizeIdempotencyKey("")
	require.NoError(t, err)
	require.Equal(t, "", key)
	// 129 bytes exceeds the allowed key length.
	_, err = NormalizeIdempotencyKey(string(make([]byte, 129)))
	require.Error(t, err)
	_, err = NormalizeIdempotencyKey("bad\nkey")
	require.Error(t, err)
	// Empty method/route/actor normalize to the same fingerprint as the
	// explicit defaults ("POST", "/", "anonymous").
	fp1, err := BuildIdempotencyFingerprint("", "", "", map[string]any{"a": 1})
	require.NoError(t, err)
	require.NotEmpty(t, fp1)
	fp2, err := BuildIdempotencyFingerprint("POST", "/", "anonymous", map[string]any{"a": 1})
	require.NoError(t, err)
	require.Equal(t, fp1, fp2)
	// A channel cannot be JSON-marshaled, so the payload is invalid.
	_, err = BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"bad": make(chan int)})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyInvalidPayload), infraerrors.Code(err))
}
// TestRetryAfterSecondsFromErrorBranches covers the nil-error, plain-error,
// valid-metadata, and malformed-metadata branches of RetryAfterSecondsFromError.
func TestRetryAfterSecondsFromErrorBranches(t *testing.T) {
	require.Equal(t, 0, RetryAfterSecondsFromError(nil))
	require.Equal(t, 0, RetryAfterSecondsFromError(errors.New("plain")))
	withValid := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "12"})
	require.Equal(t, 12, RetryAfterSecondsFromError(withValid))
	withBad := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "bad"})
	require.Equal(t, 0, RetryAfterSecondsFromError(withBad))
}
// TestIdempotencyCoordinator_ExecuteNilExecutorAndNoKeyPassThrough covers two
// guard paths: a nil executor is rejected outright, and a keyless request
// under the default config passes straight through to the executor (contrast
// TestIdempotencyCoordinator_RequireKey, which flips ObserveOnly off to make
// the missing key an error).
func TestIdempotencyCoordinator_ExecuteNilExecutorAndNoKeyPassThrough(t *testing.T) {
	repo := newInMemoryIdempotencyRepo()
	coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
	_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Payload:        map[string]any{"a": 1},
	}, nil)
	require.Error(t, err)
	require.Equal(t, "IDEMPOTENCY_EXECUTOR_NIL", infraerrors.Reason(err))
	called := 0
	result, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:      "scope",
		RequireKey: true,
		Payload:    map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		called++
		return map[string]any{"ok": true}, nil
	})
	require.NoError(t, err)
	require.Equal(t, 1, called)
	require.NotNil(t, result)
	require.False(t, result.Replayed)
}
// noIDOwnerRepo claims every record successfully (CreateProcessing returns
// true) but never assigns it an ID, exercising the coordinator's
// "record ID missing" defensive branch.
type noIDOwnerRepo struct{}
func (noIDOwnerRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	return true, nil
}
func (noIDOwnerRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	return nil, nil
}
func (noIDOwnerRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, nil
}
func (noIDOwnerRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, nil
}
func (noIDOwnerRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { return nil }
func (noIDOwnerRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return nil
}
func (noIDOwnerRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { return 0, nil }
// TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing covers
// three configuration/defensive failures: a nil repository, a request with no
// scope, and a repo that claims success without assigning a record ID.
func TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing(t *testing.T) {
	cfg := DefaultIdempotencyConfig()
	// Nil repository -> store unavailable.
	coordinator := NewIdempotencyCoordinator(nil, cfg)
	_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Missing scope -> IDEMPOTENCY_SCOPE_REQUIRED.
	coordinator = NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg)
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		IdempotencyKey: "k2",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, "IDEMPOTENCY_SCOPE_REQUIRED", infraerrors.Reason(err))
	// Claim succeeded but no record ID assigned -> store unavailable.
	coordinator = NewIdempotencyCoordinator(noIDOwnerRepo{}, cfg)
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope-no-id",
		IdempotencyKey: "k3",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
}
// conflictBranchRepo drives the get-after-failed-create branches of Execute:
// CreateProcessing always loses (false, nil), GetByScopeAndKeyHash returns a
// configurable existing record, and the TryReclaim outcome is controlled via
// tryReclaimOK / tryReclaimErr.
type conflictBranchRepo struct {
	existing      *IdempotencyRecord
	tryReclaimErr error
	tryReclaimOK  bool
}
func (r *conflictBranchRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	return false, nil
}
func (r *conflictBranchRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	return cloneRecord(r.existing), nil
}
func (r *conflictBranchRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	if r.tryReclaimErr != nil {
		return false, r.tryReclaimErr
	}
	return r.tryReclaimOK, nil
}
func (r *conflictBranchRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, nil
}
func (r *conflictBranchRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return nil
}
func (r *conflictBranchRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return nil
}
func (r *conflictBranchRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) {
	return 0, nil
}
// TestIdempotencyCoordinator_ConflictBranchesAndDecodeError walks four
// branches of the get-after-failed-create path: a SUCCEEDED record whose
// stored body cannot be decoded, a record in an unknown status, a TryReclaim
// store error, and a contended TryReclaim (record still in progress).
func TestIdempotencyCoordinator_ConflictBranchesAndDecodeError(t *testing.T) {
	now := time.Now()
	fp, err := BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"a": 1})
	require.NoError(t, err)
	// Branch 1: SUCCEEDED record with an undecodable body -> store unavailable.
	badBody := "{bad-json"
	repo := &conflictBranchRepo{
		existing: &IdempotencyRecord{
			ID:                 1,
			Scope:              "scope",
			IdempotencyKeyHash: HashIdempotencyKey("k"),
			RequestFingerprint: fp,
			Status:             IdempotencyStatusSucceeded,
			ResponseBody:       &badBody,
			ExpiresAt:          now.Add(time.Hour),
		},
	}
	coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Method:         "POST",
		Route:          "/x",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Branch 2: unknown status -> key conflict.
	repo.existing = &IdempotencyRecord{
		ID:                 2,
		Scope:              "scope",
		IdempotencyKeyHash: HashIdempotencyKey("k"),
		RequestFingerprint: fp,
		Status:             "unknown",
		ExpiresAt:          now.Add(time.Hour),
	}
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Method:         "POST",
		Route:          "/x",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyKeyConflict), infraerrors.Code(err))
	// Branch 3: retryable record with expired lock, but TryReclaim errors
	// -> store unavailable.
	repo.existing = &IdempotencyRecord{
		ID:                 3,
		Scope:              "scope",
		IdempotencyKeyHash: HashIdempotencyKey("k"),
		RequestFingerprint: fp,
		Status:             IdempotencyStatusFailedRetryable,
		LockedUntil:        ptrTime(now.Add(-time.Second)),
		ExpiresAt:          now.Add(time.Hour),
	}
	repo.tryReclaimErr = errors.New("reclaim down")
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Method:         "POST",
		Route:          "/x",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Branch 4: TryReclaim lost (another owner reclaimed) -> in progress.
	repo.tryReclaimErr = nil
	repo.tryReclaimOK = false
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope",
		IdempotencyKey: "k",
		Method:         "POST",
		Route:          "/x",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyInProgress), infraerrors.Code(err))
}
// markBehaviorRepo wraps the in-memory repo and lets individual tests force
// MarkSucceeded / MarkFailedRetryable to fail.
//
// It embeds a *pointer*: inMemoryIdempotencyRepo contains a mutex (locked in
// every method), and copying it by value via `*newInMemoryIdempotencyRepo()`
// trips `go vet`'s copylocks check.
type markBehaviorRepo struct {
	*inMemoryIdempotencyRepo
	failMarkSucceeded bool
	failMarkFailed    bool
}

// MarkSucceeded fails when configured to, otherwise delegates to the embedded repo.
func (r *markBehaviorRepo) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
	if r.failMarkSucceeded {
		return errors.New("mark succeeded failed")
	}
	return r.inMemoryIdempotencyRepo.MarkSucceeded(ctx, id, responseStatus, responseBody, expiresAt)
}

// MarkFailedRetryable fails when configured to, otherwise delegates to the embedded repo.
func (r *markBehaviorRepo) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
	if r.failMarkFailed {
		return errors.New("mark failed retryable failed")
	}
	return r.inMemoryIdempotencyRepo.MarkFailedRetryable(ctx, id, errorReason, lockedUntil, expiresAt)
}

// TestIdempotencyCoordinator_MarkAndMarshalBranches covers three finalization
// failures: MarkSucceeded errors, the business result cannot be marshaled for
// storage, and MarkFailedRetryable errors (in which case the original business
// error is surfaced unchanged).
func TestIdempotencyCoordinator_MarkAndMarshalBranches(t *testing.T) {
	repo := &markBehaviorRepo{inMemoryIdempotencyRepo: newInMemoryIdempotencyRepo()}
	coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig())
	// MarkSucceeded failure after a successful execution -> store unavailable.
	repo.failMarkSucceeded = true
	_, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope-success",
		IdempotencyKey: "k1",
		Method:         "POST",
		Route:          "/ok",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"ok": true}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// A channel in the result cannot be JSON-marshaled -> store unavailable.
	repo.failMarkSucceeded = false
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope-marshal",
		IdempotencyKey: "k2",
		Method:         "POST",
		Route:          "/bad",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return map[string]any{"bad": make(chan int)}, nil
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// MarkFailedRetryable failure must not mask the business error.
	repo.failMarkFailed = true
	_, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{
		Scope:          "scope-fail",
		IdempotencyKey: "k3",
		Method:         "POST",
		Route:          "/fail",
		ActorScope:     "u:1",
		Payload:        map[string]any{"a": 1},
	}, func(ctx context.Context) (any, error) {
		return nil, errors.New("plain failure")
	})
	require.Error(t, err)
	require.Equal(t, "plain failure", err.Error())
}
// TestIdempotencyCoordinator_HelperBranches unit-tests small coordinator
// helpers directly: retry-after fallback without a lock time, response
// truncation at MaxStoredResponseLen, and stored-response decoding of empty
// and invalid bodies.
func TestIdempotencyCoordinator_HelperBranches(t *testing.T) {
	c := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{
		DefaultTTL:           time.Hour,
		SystemOperationTTL:   time.Hour,
		ProcessingTimeout:    time.Second,
		FailedRetryBackoff:   time.Second,
		MaxStoredResponseLen: 12,
		ObserveOnly:          false,
	})
	// conflictWithRetryAfter without locked_until should return base error.
	base := ErrIdempotencyInProgress
	err := c.conflictWithRetryAfter(base, nil, time.Now())
	require.Equal(t, infraerrors.Code(base), infraerrors.Code(err))
	// marshalStoredResponse should truncate (MaxStoredResponseLen is 12).
	body, err := c.marshalStoredResponse(map[string]any{"long": "abcdefghijklmnopqrstuvwxyz"})
	require.NoError(t, err)
	require.Contains(t, body, "...(truncated)")
	// decodeStoredResponse: nil body decodes to an empty map, invalid JSON errors.
	out, err := c.decodeStoredResponse(nil)
	require.NoError(t, err)
	_, ok := out.(map[string]any)
	require.True(t, ok)
	invalid := "{invalid"
	_, err = c.decodeStoredResponse(&invalid)
	require.Error(t, err)
}

View File

@@ -0,0 +1,389 @@
package service
import (
"context"
"strconv"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// groupRepoNoop is a GroupRepository stand-in whose every method panics, so a
// test fails loudly if the service under test touches a repository call the
// test did not expect. Concrete stubs embed it and override only what they need.
type groupRepoNoop struct{}
func (groupRepoNoop) Create(context.Context, *Group) error { panic("unexpected Create call") }
func (groupRepoNoop) GetByID(context.Context, int64) (*Group, error) {
	panic("unexpected GetByID call")
}
func (groupRepoNoop) GetByIDLite(context.Context, int64) (*Group, error) {
	panic("unexpected GetByIDLite call")
}
func (groupRepoNoop) Update(context.Context, *Group) error { panic("unexpected Update call") }
func (groupRepoNoop) Delete(context.Context, int64) error  { panic("unexpected Delete call") }
func (groupRepoNoop) DeleteCascade(context.Context, int64) ([]int64, error) {
	panic("unexpected DeleteCascade call")
}
func (groupRepoNoop) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected List call")
}
func (groupRepoNoop) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected ListWithFilters call")
}
func (groupRepoNoop) ListActive(context.Context) ([]Group, error) {
	panic("unexpected ListActive call")
}
func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, error) {
	panic("unexpected ListActiveByPlatform call")
}
func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) {
	panic("unexpected ExistsByName call")
}
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) {
	panic("unexpected GetAccountCount call")
}
func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
	panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (groupRepoNoop) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
	panic("unexpected GetAccountIDsByGroupIDs call")
}
func (groupRepoNoop) BindAccountsToGroup(context.Context, int64, []int64) error {
	panic("unexpected BindAccountsToGroup call")
}
func (groupRepoNoop) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
	panic("unexpected UpdateSortOrders call")
}
// subscriptionGroupRepoStub returns a fixed group from GetByID and panics
// (via the embedded groupRepoNoop) on any other repository call.
type subscriptionGroupRepoStub struct {
	groupRepoNoop
	group *Group
}
func (s *subscriptionGroupRepoStub) GetByID(context.Context, int64) (*Group, error) {
	return s.group, nil
}
// userSubRepoNoop is a UserSubscription repository stand-in whose every method
// panics, making any unexpected repository interaction fail the test loudly.
// Concrete stubs embed it and override only the calls they expect.
type userSubRepoNoop struct{}
func (userSubRepoNoop) Create(context.Context, *UserSubscription) error {
	panic("unexpected Create call")
}
func (userSubRepoNoop) GetByID(context.Context, int64) (*UserSubscription, error) {
	panic("unexpected GetByID call")
}
func (userSubRepoNoop) GetByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
	panic("unexpected GetByUserIDAndGroupID call")
}
func (userSubRepoNoop) GetActiveByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
	panic("unexpected GetActiveByUserIDAndGroupID call")
}
func (userSubRepoNoop) Update(context.Context, *UserSubscription) error {
	panic("unexpected Update call")
}
func (userSubRepoNoop) Delete(context.Context, int64) error { panic("unexpected Delete call") }
func (userSubRepoNoop) ListByUserID(context.Context, int64) ([]UserSubscription, error) {
	panic("unexpected ListByUserID call")
}
func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscription, error) {
	panic("unexpected ListActiveByUserID call")
}
func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
	panic("unexpected ListByGroupID call")
}
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
	panic("unexpected List call")
}
func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {
	panic("unexpected ExistsByUserIDAndGroupID call")
}
func (userSubRepoNoop) ExtendExpiry(context.Context, int64, time.Time) error {
	panic("unexpected ExtendExpiry call")
}
func (userSubRepoNoop) UpdateStatus(context.Context, int64, string) error {
	panic("unexpected UpdateStatus call")
}
func (userSubRepoNoop) UpdateNotes(context.Context, int64, string) error {
	panic("unexpected UpdateNotes call")
}
func (userSubRepoNoop) ActivateWindows(context.Context, int64, time.Time) error {
	panic("unexpected ActivateWindows call")
}
func (userSubRepoNoop) ResetDailyUsage(context.Context, int64, time.Time) error {
	panic("unexpected ResetDailyUsage call")
}
func (userSubRepoNoop) ResetWeeklyUsage(context.Context, int64, time.Time) error {
	panic("unexpected ResetWeeklyUsage call")
}
func (userSubRepoNoop) ResetMonthlyUsage(context.Context, int64, time.Time) error {
	panic("unexpected ResetMonthlyUsage call")
}
func (userSubRepoNoop) IncrementUsage(context.Context, int64, float64) error {
	panic("unexpected IncrementUsage call")
}
func (userSubRepoNoop) BatchUpdateExpiredStatus(context.Context) (int64, error) {
	panic("unexpected BatchUpdateExpiredStatus call")
}
// subscriptionUserSubRepoStub is an in-memory UserSubscription repository used
// by the assign-subscription tests. It indexes rows by ID and by
// (userID, groupID), and counts Create calls so tests can assert how many new
// subscriptions were actually persisted.
type subscriptionUserSubRepoStub struct {
	userSubRepoNoop
	nextID      int64
	byID        map[int64]*UserSubscription
	byUserGroup map[string]*UserSubscription
	createCalls int
}
// newSubscriptionUserSubRepoStub returns an empty stub with IDs starting at 1.
func newSubscriptionUserSubRepoStub() *subscriptionUserSubRepoStub {
	return &subscriptionUserSubRepoStub{
		nextID:      1,
		byID:        make(map[int64]*UserSubscription),
		byUserGroup: make(map[string]*UserSubscription),
	}
}
// key builds the composite lookup key for a (userID, groupID) pair, using the
// standard library directly instead of a local strconv wrapper.
func (s *subscriptionUserSubRepoStub) key(userID, groupID int64) string {
	return strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10)
}
// seed stores a copy of sub in both indexes, assigning the next free ID when
// the caller left ID at zero. A nil sub is ignored.
func (s *subscriptionUserSubRepoStub) seed(sub *UserSubscription) {
	if sub == nil {
		return
	}
	stored := *sub
	if stored.ID == 0 {
		stored.ID = s.nextID
		s.nextID++
	}
	s.byID[stored.ID] = &stored
	s.byUserGroup[s.key(stored.UserID, stored.GroupID)] = &stored
}
// ExistsByUserIDAndGroupID reports whether a subscription is stored for the pair.
func (s *subscriptionUserSubRepoStub) ExistsByUserIDAndGroupID(_ context.Context, userID, groupID int64) (bool, error) {
	_, found := s.byUserGroup[s.key(userID, groupID)]
	return found, nil
}
// GetByUserIDAndGroupID returns a copy of the stored subscription for the
// pair, or ErrSubscriptionNotFound when none exists.
func (s *subscriptionUserSubRepoStub) GetByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
	stored := s.byUserGroup[s.key(userID, groupID)]
	if stored == nil {
		return nil, ErrSubscriptionNotFound
	}
	out := *stored
	return &out, nil
}
// Create stores a copy of sub in both indexes, assigns an ID when the caller
// left it at zero, mirrors the ID back to the caller's struct, and counts the
// call so tests can assert persistence behavior. Nil input is a no-op.
func (s *subscriptionUserSubRepoStub) Create(_ context.Context, sub *UserSubscription) error {
	if sub == nil {
		return nil
	}
	s.createCalls++
	stored := *sub
	if stored.ID == 0 {
		stored.ID = s.nextID
		s.nextID++
	}
	sub.ID = stored.ID
	s.byID[stored.ID] = &stored
	s.byUserGroup[s.key(stored.UserID, stored.GroupID)] = &stored
	return nil
}
// GetByID returns a copy of the stored subscription with the given ID, or
// ErrSubscriptionNotFound when none exists.
func (s *subscriptionUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) {
	stored := s.byID[id]
	if stored == nil {
		return nil, ErrSubscriptionNotFound
	}
	out := *stored
	return &out, nil
}
// TestAssignSubscriptionReuseWhenSemanticsMatch verifies that assigning a
// subscription that already exists with identical semantics (validity window,
// notes) returns the existing row and does not create a new one.
func TestAssignSubscriptionReuseWhenSemanticsMatch(t *testing.T) {
	start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
	groupRepo := &subscriptionGroupRepoStub{
		group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
	}
	subRepo := newSubscriptionUserSubRepoStub()
	// Pre-existing row matching the request exactly (30-day window, same notes).
	subRepo.seed(&UserSubscription{
		ID:        10,
		UserID:    1001,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 30),
		Notes:     "init",
	})
	svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
	sub, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
		UserID:       1001,
		GroupID:      1,
		ValidityDays: 30,
		Notes:        "init",
	})
	require.NoError(t, err)
	require.Equal(t, int64(10), sub.ID)
	require.Equal(t, 0, subRepo.createCalls, "reuse should not create new subscription")
}
// TestAssignSubscriptionConflictWhenSemanticsMismatch verifies that assigning
// over an existing subscription with different semantics (here: different
// notes) is rejected as a conflict without creating or mutating anything.
func TestAssignSubscriptionConflictWhenSemanticsMismatch(t *testing.T) {
	start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
	groupRepo := &subscriptionGroupRepoStub{
		group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
	}
	subRepo := newSubscriptionUserSubRepoStub()
	subRepo.seed(&UserSubscription{
		ID:        11,
		UserID:    2001,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 30),
		Notes:     "old-note",
	})
	svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
	_, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
		UserID:       2001,
		GroupID:      1,
		ValidityDays: 30,
		Notes:        "new-note",
	})
	require.Error(t, err)
	require.Equal(t, "SUBSCRIPTION_ASSIGN_CONFLICT", infraerrorsReason(err))
	require.Equal(t, 0, subRepo.createCalls, "conflict should not create or mutate existing subscription")
}
// TestBulkAssignSubscriptionCreatedReusedAndConflict verifies per-user
// outcomes of a bulk assignment: an existing matching row is reused, a new
// user gets a created row, and a semantically conflicting row fails — with the
// aggregate counters reflecting all three.
func TestBulkAssignSubscriptionCreatedReusedAndConflict(t *testing.T) {
	start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
	groupRepo := &subscriptionGroupRepoStub{
		group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
	}
	subRepo := newSubscriptionUserSubRepoStub()
	// user 1: semantics match the request, so the row should be reused.
	subRepo.seed(&UserSubscription{
		ID:        21,
		UserID:    1,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 30),
		Notes:     "same-note",
	})
	// user 3: semantic conflict (validity period differs), so it should fail.
	subRepo.seed(&UserSubscription{
		ID:        23,
		UserID:    3,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 60),
		Notes:     "same-note",
	})
	svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
	result, err := svc.BulkAssignSubscription(context.Background(), &BulkAssignSubscriptionInput{
		UserIDs:      []int64{1, 2, 3},
		GroupID:      1,
		ValidityDays: 30,
		AssignedBy:   9,
		Notes:        "same-note",
	})
	require.NoError(t, err)
	require.Equal(t, 2, result.SuccessCount)
	require.Equal(t, 1, result.CreatedCount)
	require.Equal(t, 1, result.ReusedCount)
	require.Equal(t, 1, result.FailedCount)
	require.Equal(t, "reused", result.Statuses[1])
	require.Equal(t, "created", result.Statuses[2])
	require.Equal(t, "failed", result.Statuses[3])
	// Only user 2 (no pre-existing row) should have triggered a Create.
	require.Equal(t, 1, subRepo.createCalls)
}
// TestAssignSubscriptionKeepsWorkingWhenIdempotencyStoreUnavailable checks
// that subscription assignment (a semantically idempotent endpoint) still
// succeeds when the idempotency store behind the default coordinator fails.
func TestAssignSubscriptionKeepsWorkingWhenIdempotencyStoreUnavailable(t *testing.T) {
	groups := &subscriptionGroupRepoStub{
		group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
	}
	subs := newSubscriptionUserSubRepoStub()
	// Install a coordinator whose backing store always errors; restore the
	// default once the test finishes.
	SetDefaultIdempotencyCoordinator(NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig()))
	t.Cleanup(func() { SetDefaultIdempotencyCoordinator(nil) })
	svc := NewSubscriptionService(groups, subs, nil, nil, nil)
	created, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
		UserID:       9001,
		GroupID:      1,
		ValidityDays: 30,
		Notes:        "new",
	})
	require.NoError(t, err)
	require.NotNil(t, created)
	require.Equal(t, 1, subs.createCalls, "semantic idempotent endpoint should not depend on idempotency store availability")
}
// TestNormalizeAssignValidityDays verifies default fallback and clamping of
// the requested validity period.
func TestNormalizeAssignValidityDays(t *testing.T) {
	cases := []struct {
		in   int
		want int
	}{
		{in: 0, want: 30},                                    // zero falls back to the 30-day default
		{in: -5, want: 30},                                   // negative falls back as well
		{in: MaxValidityDays + 100, want: MaxValidityDays},   // values above the cap are clamped
		{in: 7, want: 7},                                     // in-range values pass through
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, normalizeAssignValidityDays(tc.in))
	}
}
// TestDetectAssignSemanticConflictCases exercises the three semantic
// comparison outcomes against one existing subscription: no conflict,
// validity mismatch, and notes mismatch.
func TestDetectAssignSemanticConflictCases(t *testing.T) {
	start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
	existing := &UserSubscription{
		UserID:    1,
		GroupID:   1,
		StartsAt:  start,
		ExpiresAt: start.AddDate(0, 0, 30),
		Notes:     "same",
	}
	cases := []struct {
		name         string
		input        *AssignSubscriptionInput
		wantReason   string
		wantConflict bool
	}{
		{
			name:  "identical semantics",
			input: &AssignSubscriptionInput{UserID: 1, GroupID: 1, ValidityDays: 30, Notes: "same"},
		},
		{
			name:         "different validity window",
			input:        &AssignSubscriptionInput{UserID: 1, GroupID: 1, ValidityDays: 60, Notes: "same"},
			wantReason:   "validity_days_mismatch",
			wantConflict: true,
		},
		{
			name:         "different notes",
			input:        &AssignSubscriptionInput{UserID: 1, GroupID: 1, ValidityDays: 30, Notes: "other"},
			wantReason:   "notes_mismatch",
			wantConflict: true,
		},
	}
	for _, tc := range cases {
		reason, conflict := detectAssignSemanticConflict(existing, tc.input)
		require.Equal(t, tc.wantConflict, conflict, tc.name)
		require.Equal(t, tc.wantReason, reason, tc.name)
	}
}
// TestAssignSubscriptionGroupTypeValidation ensures assignment is rejected
// when the target group is not a subscription-type group.
func TestAssignSubscriptionGroupTypeValidation(t *testing.T) {
	svc := NewSubscriptionService(
		&subscriptionGroupRepoStub{group: &Group{ID: 1, SubscriptionType: SubscriptionTypeStandard}},
		newSubscriptionUserSubRepoStub(),
		nil, nil, nil,
	)
	_, err := svc.AssignSubscription(context.Background(), &AssignSubscriptionInput{
		UserID:       1,
		GroupID:      1,
		ValidityDays: 30,
	})
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrGroupNotSubscriptionType), infraerrors.Code(err))
}
// strconvFormatInt renders v in base 10; small local alias so test
// expectations read uniformly.
func strconvFormatInt(v int64) string {
	return strconv.FormatInt(v, 10)
}

// infraerrorsReason extracts the machine-readable reason code from err.
func infraerrorsReason(err error) string {
	return infraerrors.Reason(err)
}

View File

@@ -6,6 +6,7 @@ import (
"log"
"math/rand/v2"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -24,16 +25,17 @@ var MaxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
const MaxValidityDays = 36500
var (
ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
)
// SubscriptionService 订阅服务
@@ -150,40 +152,10 @@ type AssignSubscriptionInput struct {
// AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
// 检查分组是否存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, ErrGroupNotSubscriptionType
}
// 检查是否已存在订阅
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
sub, _, err := s.assignSubscriptionWithReuse(ctx, input)
if err != nil {
return nil, err
}
if exists {
return nil, ErrSubscriptionAlreadyExists
}
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, err
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return sub, nil
}
@@ -363,9 +335,12 @@ type BulkAssignSubscriptionInput struct {
// BulkAssignResult summarizes the outcome of a bulk subscription assignment.
type BulkAssignResult struct {
	SuccessCount  int                // number of users assigned successfully (created + reused)
	CreatedCount  int                // newly created subscriptions
	ReusedCount   int                // existing semantically-equal subscriptions returned as-is
	FailedCount   int                // users whose assignment errored or conflicted
	Subscriptions []UserSubscription // successful subscriptions (created or reused)
	Errors        []string           // per-user error descriptions for failures
	Statuses      map[int64]string   // user id -> "created" | "reused" | "failed"
}
// BulkAssignSubscription 批量分配订阅
@@ -373,10 +348,11 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
result := &BulkAssignResult{
Subscriptions: make([]UserSubscription, 0),
Errors: make([]string, 0),
Statuses: make(map[int64]string),
}
for _, userID := range input.UserIDs {
sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{
sub, reused, err := s.assignSubscriptionWithReuse(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: input.GroupID,
ValidityDays: input.ValidityDays,
@@ -386,15 +362,105 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
if err != nil {
result.FailedCount++
result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
result.Statuses[userID] = "failed"
} else {
result.SuccessCount++
result.Subscriptions = append(result.Subscriptions, *sub)
if reused {
result.ReusedCount++
result.Statuses[userID] = "reused"
} else {
result.CreatedCount++
result.Statuses[userID] = "created"
}
}
}
return result, nil
}
// assignSubscriptionWithReuse assigns a subscription group to a user with
// idempotent semantics. If a subscription already exists it is returned
// as-is with reused=true — unless its semantics conflict with the request,
// in which case ErrSubscriptionAssignConflict (with a conflict_reason) is
// returned. Otherwise a new subscription is created (reused=false) and the
// related caches are invalidated.
func (s *SubscriptionService) assignSubscriptionWithReuse(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
	// Ensure the target group exists and is a subscription-type group.
	group, err := s.groupRepo.GetByID(ctx, input.GroupID)
	if err != nil {
		return nil, false, fmt.Errorf("group not found: %w", err)
	}
	if !group.IsSubscriptionType() {
		return nil, false, ErrGroupNotSubscriptionType
	}
	// If a subscription already exists, treat the request as an idempotent
	// replay and return the existing row instead of creating a duplicate.
	exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
	if err != nil {
		return nil, false, err
	}
	if exists {
		sub, getErr := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
		if getErr != nil {
			return nil, false, getErr
		}
		// A replay is only honored when the request means the same thing as
		// the stored subscription; otherwise surface an explicit conflict.
		if conflictReason, conflict := detectAssignSemanticConflict(sub, input); conflict {
			return nil, false, ErrSubscriptionAssignConflict.WithMetadata(map[string]string{
				"conflict_reason": conflictReason,
			})
		}
		return sub, true, nil
	}
	sub, err := s.createSubscription(ctx, input)
	if err != nil {
		return nil, false, err
	}
	// Invalidate the local subscription cache for this user/group pair.
	s.InvalidateSubCache(input.UserID, input.GroupID)
	if s.billingCacheService != nil {
		userID, groupID := input.UserID, input.GroupID
		// Billing cache invalidation is best-effort and must not block the
		// request path; it runs in the background with its own timeout.
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
		}()
	}
	return sub, false, nil
}
// detectAssignSemanticConflict reports whether a new assignment request is
// incompatible with an already-existing subscription. It returns a
// machine-readable reason ("validity_days_mismatch" or "notes_mismatch")
// plus true when the request conflicts, and ("", false) otherwise.
// Nil arguments are treated as non-conflicting.
func detectAssignSemanticConflict(existing *UserSubscription, input *AssignSubscriptionInput) (string, bool) {
	if existing == nil || input == nil {
		return "", false
	}
	if !existing.StartsAt.IsZero() {
		// Recompute the expiry the request implies (clamped at MaxExpiresAt)
		// and compare it with the stored one; a difference means a different
		// validity period was requested.
		wantExpiry := existing.StartsAt.AddDate(0, 0, normalizeAssignValidityDays(input.ValidityDays))
		if wantExpiry.After(MaxExpiresAt) {
			wantExpiry = MaxExpiresAt
		}
		if !existing.ExpiresAt.Equal(wantExpiry) {
			return "validity_days_mismatch", true
		}
	}
	if strings.TrimSpace(existing.Notes) != strings.TrimSpace(input.Notes) {
		return "notes_mismatch", true
	}
	return "", false
}
// normalizeAssignValidityDays clamps a requested validity period (in days)
// into the supported range: non-positive values fall back to the 30-day
// default, and values above MaxValidityDays are capped at MaxValidityDays.
func normalizeAssignValidityDays(days int) int {
	switch {
	case days <= 0:
		return 30
	case days > MaxValidityDays:
		return MaxValidityDays
	default:
		return days
	}
}
// RevokeSubscription 撤销订阅
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
// 先获取订阅信息用于失效缓存

View File

@@ -0,0 +1,214 @@
package service
import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"sync"
	"time"

	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
	// systemOperationLockScope namespaces the lock record in the idempotency store.
	systemOperationLockScope = "admin.system.operations.global_lock"
	// systemOperationLockKey is the single well-known key: at most one
	// global system operation may run at a time.
	systemOperationLockKey = "global-system-operation-lock"
)

var (
	// ErrSystemOperationBusy is returned when another operation holds the lock.
	ErrSystemOperationBusy = infraerrors.Conflict("SYSTEM_OPERATION_BUSY", "another system operation is in progress")
)
// SystemOperationLock represents a held global system-operation lock.
// It is returned by Acquire and must eventually be passed to Release.
type SystemOperationLock struct {
	recordID    int64         // id of the backing idempotency record
	operationID string        // caller-supplied id of the running operation
	stopOnce    sync.Once     // guards closing stopCh exactly once
	stopCh      chan struct{} // closed by Release to stop the renew goroutine
}
// OperationID returns the id of the operation holding this lock.
// It is safe to call on a nil lock, in which case it returns "".
func (l *SystemOperationLock) OperationID() string {
	if l != nil {
		return l.operationID
	}
	return ""
}
// SystemOperationLockService coordinates a single global lock for mutually
// exclusive system operations on top of the idempotency store.
type SystemOperationLockService struct {
	repo          IdempotencyRepository // persistent lock/record storage
	lease         time.Duration         // how long one lease lasts before it must be renewed
	renewInterval time.Duration         // how often the renew goroutine extends the lease
	ttl           time.Duration         // how long the record is retained (ExpiresAt horizon)
}
// NewSystemOperationLockService builds a lock service on top of the
// idempotency store, filling in defaults for unset (non-positive) config
// values: a 30s lease, a renew interval of lease/3 (never below 1s), and a
// 1h record TTL.
func NewSystemOperationLockService(repo IdempotencyRepository, cfg IdempotencyConfig) *SystemOperationLockService {
	svc := &SystemOperationLockService{repo: repo}
	svc.lease = cfg.ProcessingTimeout
	if svc.lease <= 0 {
		svc.lease = 30 * time.Second
	}
	svc.renewInterval = svc.lease / 3
	if svc.renewInterval < time.Second {
		svc.renewInterval = time.Second
	}
	svc.ttl = cfg.SystemOperationTTL
	if svc.ttl <= 0 {
		svc.ttl = time.Hour
	}
	return svc
}
// Acquire takes the process-wide system operation lock on behalf of
// operationID. Only one operation may hold the lock at a time; a held lock
// yields ErrSystemOperationBusy carrying operation_id/retry_after metadata.
// On success a background goroutine renews the lease until Release is called.
func (s *SystemOperationLockService) Acquire(ctx context.Context, operationID string) (*SystemOperationLock, error) {
	if s == nil || s.repo == nil {
		return nil, ErrIdempotencyStoreUnavail
	}
	if operationID == "" {
		return nil, infraerrors.BadRequest("SYSTEM_OPERATION_ID_REQUIRED", "operation id is required")
	}
	now := time.Now()
	expiresAt := now.Add(s.ttl)
	lockedUntil := now.Add(s.lease)
	keyHash := HashIdempotencyKey(systemOperationLockKey)
	// The lock is modeled as an idempotency record under one fixed scope/key;
	// the holder's operation id rides along in RequestFingerprint.
	record := &IdempotencyRecord{
		Scope:              systemOperationLockScope,
		IdempotencyKeyHash: keyHash,
		RequestFingerprint: operationID,
		Status:             IdempotencyStatusProcessing,
		LockedUntil:        &lockedUntil,
		ExpiresAt:          expiresAt,
	}
	owner, err := s.repo.CreateProcessing(ctx, record)
	if err != nil {
		return nil, ErrIdempotencyStoreUnavail.WithCause(err)
	}
	if !owner {
		// A record already exists: inspect it and, if the previous lease has
		// lapsed (crashed holder, finished run), try to reclaim it atomically.
		existing, getErr := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash)
		if getErr != nil {
			return nil, ErrIdempotencyStoreUnavail.WithCause(getErr)
		}
		if existing == nil {
			// CreateProcessing said "not owner" yet nothing is stored —
			// treat this inconsistency as a store failure.
			return nil, ErrIdempotencyStoreUnavail
		}
		if existing.Status == IdempotencyStatusProcessing && existing.LockedUntil != nil && existing.LockedUntil.After(now) {
			// Lease still valid: someone else is actively running.
			return nil, s.busyError(existing.RequestFingerprint, existing.LockedUntil, now)
		}
		// TryReclaim is conditioned on the status we observed so a concurrent
		// reclaimer cannot both win.
		reclaimed, reclaimErr := s.repo.TryReclaim(
			ctx,
			existing.ID,
			existing.Status,
			now,
			lockedUntil,
			expiresAt,
		)
		if reclaimErr != nil {
			return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr)
		}
		if !reclaimed {
			// Lost the reclaim race; report busy with the new holder's info
			// when it can still be read back.
			latest, _ := s.repo.GetByScopeAndKeyHash(ctx, systemOperationLockScope, keyHash)
			if latest != nil {
				return nil, s.busyError(latest.RequestFingerprint, latest.LockedUntil, now)
			}
			return nil, ErrSystemOperationBusy
		}
		record.ID = existing.ID
	}
	if record.ID == 0 {
		// Defensive: without a record id the renew loop could not target the
		// row, so fail rather than hand out an unrenewable lock.
		return nil, ErrIdempotencyStoreUnavail
	}
	lock := &SystemOperationLock{
		recordID:    record.ID,
		operationID: operationID,
		stopCh:      make(chan struct{}),
	}
	go s.renewLoop(lock)
	return lock, nil
}
// Release stops lease renewal and records the operation's final outcome.
//
// succeeded=true marks the record succeeded with a small JSON payload;
// succeeded=false marks it retryable using failureReason (defaulting to
// "SYSTEM_OPERATION_FAILED"). Safe to call multiple times, with a nil
// context, and on a nil service or lock (all no-ops returning nil).
func (s *SystemOperationLockService) Release(ctx context.Context, lock *SystemOperationLock, succeeded bool, failureReason string) error {
	if s == nil || s.repo == nil || lock == nil {
		return nil
	}
	// Stop the renew goroutine exactly once, even on repeated Release calls.
	lock.stopOnce.Do(func() {
		close(lock.stopCh)
	})
	if ctx == nil {
		ctx = context.Background()
	}
	expiresAt := time.Now().Add(s.ttl)
	if succeeded {
		// JSON-encode the operation id instead of splicing it in raw so that
		// quotes/backslashes in the id cannot corrupt the stored JSON body.
		idJSON, err := json.Marshal(lock.operationID)
		if err != nil {
			// Marshaling a string should never fail; fall back defensively.
			idJSON = []byte(`""`)
		}
		responseBody := fmt.Sprintf(`{"operation_id":%s,"released":true}`, idJSON)
		return s.repo.MarkSucceeded(ctx, lock.recordID, 200, responseBody, expiresAt)
	}
	reason := failureReason
	if reason == "" {
		reason = "SYSTEM_OPERATION_FAILED"
	}
	return s.repo.MarkFailedRetryable(ctx, lock.recordID, reason, time.Now(), expiresAt)
}
// renewLoop periodically extends the processing lease while the lock is held.
// It exits when Release closes stopCh, or when the repo reports ownership
// lost (another holder reclaimed the record). Transient storage errors are
// logged and retried on the next tick.
func (s *SystemOperationLockService) renewLoop(lock *SystemOperationLock) {
	ticker := time.NewTicker(s.renewInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			now := time.Now()
			// Each renewal gets its own short timeout so a slow store cannot
			// wedge the loop.
			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
			ok, err := s.repo.ExtendProcessingLock(
				ctx,
				lock.recordID,
				lock.operationID,
				now.Add(s.lease),
				now.Add(s.ttl),
			)
			cancel()
			if err != nil {
				logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew failed operation_id=%s err=%v", lock.operationID, err)
				// A transient failure must not kill the renew goroutine; the
				// next tick retries the extension.
				continue
			}
			if !ok {
				// We no longer own the record, so further renewals are
				// pointless — exit the goroutine.
				logger.LegacyPrintf("service.system_operation_lock", "[SystemOperationLock] renew stopped operation_id=%s reason=ownership_lost", lock.operationID)
				return
			}
		case <-lock.stopCh:
			return
		}
	}
}
// busyError builds an ErrSystemOperationBusy enriched, when the information
// is available, with the holder's operation id and a retry_after hint in
// whole seconds (floored at 1).
func (s *SystemOperationLockService) busyError(operationID string, lockedUntil *time.Time, now time.Time) error {
	md := map[string]string{}
	if operationID != "" {
		md["operation_id"] = operationID
	}
	if lockedUntil != nil {
		remaining := int(lockedUntil.Sub(now).Seconds())
		if remaining <= 0 {
			remaining = 1
		}
		md["retry_after"] = strconv.Itoa(remaining)
	}
	if len(md) == 0 {
		return ErrSystemOperationBusy
	}
	return ErrSystemOperationBusy.WithMetadata(md)
}

View File

@@ -0,0 +1,305 @@
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
// TestSystemOperationLockService_AcquireBusyAndRelease covers the basic
// lifecycle: first acquire wins, a second operation is rejected as busy with
// holder metadata, and release frees the slot for the next operation.
func TestSystemOperationLockService_AcquireBusyAndRelease(t *testing.T) {
	svc := NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	held, err := svc.Acquire(context.Background(), "op-1")
	require.NoError(t, err)
	require.NotNil(t, held)
	// While op-1 holds the lock, op-2 must be rejected with busy metadata.
	_, err = svc.Acquire(context.Background(), "op-2")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
	busy := infraerrors.FromError(err)
	require.Equal(t, "op-1", busy.Metadata["operation_id"])
	require.NotEmpty(t, busy.Metadata["retry_after"])
	// After release, the next operation acquires successfully.
	require.NoError(t, svc.Release(context.Background(), held, true, ""))
	next, err := svc.Acquire(context.Background(), "op-2")
	require.NoError(t, err)
	require.NotNil(t, next)
	require.NoError(t, svc.Release(context.Background(), next, true, ""))
}
// TestSystemOperationLockService_RenewLease verifies the background renew
// goroutine keeps pushing locked_until forward while the lock is held.
// Timing-sensitive: the sleep must exceed one renew interval (lease/3).
func TestSystemOperationLockService_RenewLease(t *testing.T) {
	repo := newInMemoryIdempotencyRepo()
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 5 * time.Second,
		ProcessingTimeout:  1200 * time.Millisecond,
	})
	lock, err := svc.Acquire(context.Background(), "op-renew")
	require.NoError(t, err)
	require.NotNil(t, lock)
	defer func() {
		_ = svc.Release(context.Background(), lock, true, "")
	}()
	keyHash := HashIdempotencyKey(systemOperationLockKey)
	// Snapshot the lease deadline right after acquisition.
	initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
	require.NotNil(t, initial)
	require.NotNil(t, initial.LockedUntil)
	initialLockedUntil := *initial.LockedUntil
	// Wait long enough for at least one renewal tick to fire.
	time.Sleep(1500 * time.Millisecond)
	updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
	require.NotNil(t, updated)
	require.NotNil(t, updated.LockedUntil)
	require.True(t, updated.LockedUntil.After(initialLockedUntil), "locked_until should be renewed while lock is held")
}
// flakySystemLockRenewRepo wraps the in-memory repo and fails the first
// ExtendProcessingLock call to simulate a transient storage error.
type flakySystemLockRenewRepo struct {
	*inMemoryIdempotencyRepo
	extendCalls int32 // number of ExtendProcessingLock calls (accessed atomically)
}
// ExtendProcessingLock fails exactly once (on its first call) with a
// transient error, then delegates to the in-memory implementation.
func (r *flakySystemLockRenewRepo) ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	if atomic.AddInt32(&r.extendCalls, 1) == 1 {
		return false, errors.New("transient extend failure")
	}
	return r.inMemoryIdempotencyRepo.ExtendProcessingLock(ctx, id, requestFingerprint, newLockedUntil, newExpiresAt)
}
// TestSystemOperationLockService_RenewLeaseContinuesAfterTransientFailure
// verifies the renew goroutine survives a one-off storage error and renews
// successfully on a subsequent tick. Timing-sensitive (Eventually polling).
func TestSystemOperationLockService_RenewLeaseContinuesAfterTransientFailure(t *testing.T) {
	repo := &flakySystemLockRenewRepo{inMemoryIdempotencyRepo: newInMemoryIdempotencyRepo()}
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 5 * time.Second,
		ProcessingTimeout:  2400 * time.Millisecond,
	})
	lock, err := svc.Acquire(context.Background(), "op-renew-transient")
	require.NoError(t, err)
	require.NotNil(t, lock)
	defer func() {
		_ = svc.Release(context.Background(), lock, true, "")
	}()
	keyHash := HashIdempotencyKey(systemOperationLockKey)
	initial, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
	require.NotNil(t, initial)
	require.NotNil(t, initial.LockedUntil)
	initialLockedUntil := *initial.LockedUntil
	// After the first renewal fails, the next tick should retry and succeed
	// in pushing the lock deadline forward.
	require.Eventually(t, func() bool {
		updated, _ := repo.GetByScopeAndKeyHash(context.Background(), systemOperationLockScope, keyHash)
		if updated == nil || updated.LockedUntil == nil {
			return false
		}
		return atomic.LoadInt32(&repo.extendCalls) >= 2 && updated.LockedUntil.After(initialLockedUntil)
	}, 4*time.Second, 100*time.Millisecond, "renew loop should continue after transient error")
}
// TestSystemOperationLockService_SameOperationIDRetryWhileRunning verifies
// that retrying with the *same* operation id while it is still running is
// rejected as busy, and only succeeds again after the lock is released.
func TestSystemOperationLockService_SameOperationIDRetryWhileRunning(t *testing.T) {
	svc := NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	first, err := svc.Acquire(context.Background(), "op-same")
	require.NoError(t, err)
	require.NotNil(t, first)
	// A duplicate acquire for the identical id must not piggyback on the
	// running one; it reports busy with the holder's id in the metadata.
	_, err = svc.Acquire(context.Background(), "op-same")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
	require.Equal(t, "op-same", infraerrors.FromError(err).Metadata["operation_id"])
	require.NoError(t, svc.Release(context.Background(), first, true, ""))
	second, err := svc.Acquire(context.Background(), "op-same")
	require.NoError(t, err)
	require.NotNil(t, second)
	require.NoError(t, svc.Release(context.Background(), second, true, ""))
}
// TestSystemOperationLockService_RecoverAfterLeaseExpired verifies that when
// a holder dies without releasing (renewal stopped), a new operation can
// reclaim the lock once the lease lapses. Timing-sensitive: the sleep must
// exceed the 300ms lease.
func TestSystemOperationLockService_RecoverAfterLeaseExpired(t *testing.T) {
	repo := newInMemoryIdempotencyRepo()
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 5 * time.Second,
		ProcessingTimeout:  300 * time.Millisecond,
	})
	lock1, err := svc.Acquire(context.Background(), "op-crashed")
	require.NoError(t, err)
	require.NotNil(t, lock1)
	// Simulate a crashed instance: stop renewal without calling Release.
	lock1.stopOnce.Do(func() {
		close(lock1.stopCh)
	})
	time.Sleep(450 * time.Millisecond)
	lock2, err := svc.Acquire(context.Background(), "op-recovered")
	require.NoError(t, err, "expired lease should allow a new operation to reclaim lock")
	require.NotNil(t, lock2)
	require.NoError(t, svc.Release(context.Background(), lock2, true, ""))
}
// systemLockRepoStub is a scriptable IdempotencyRepository used to drive
// specific Acquire/Release branches in tests.
type systemLockRepoStub struct {
	createOwner bool               // result of CreateProcessing when createErr is nil
	createErr   error              // forces CreateProcessing to fail
	existing    *IdempotencyRecord // record returned (cloned) by GetByScopeAndKeyHash
	getErr      error              // forces GetByScopeAndKeyHash to fail
	reclaimOK   bool               // result of TryReclaim when reclaimErr is nil
	reclaimErr  error              // forces TryReclaim to fail
	markSuccErr error              // forced result of MarkSucceeded
	markFailErr error              // forced result of MarkFailedRetryable
}
// CreateProcessing returns the scripted ownership flag or error.
func (s *systemLockRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) {
	if s.createErr != nil {
		return false, s.createErr
	}
	return s.createOwner, nil
}

// GetByScopeAndKeyHash returns a clone of the scripted record, or getErr.
func (s *systemLockRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) {
	if s.getErr != nil {
		return nil, s.getErr
	}
	return cloneRecord(s.existing), nil
}

// TryReclaim returns the scripted reclaim outcome or error.
func (s *systemLockRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	if s.reclaimErr != nil {
		return false, s.reclaimErr
	}
	return s.reclaimOK, nil
}

// ExtendProcessingLock always succeeds so background renewal never
// interferes with the branch under test.
func (s *systemLockRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return true, nil
}

// MarkSucceeded returns the scripted error (nil by default).
func (s *systemLockRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return s.markSuccErr
}

// MarkFailedRetryable returns the scripted error (nil by default).
func (s *systemLockRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return s.markFailErr
}

// DeleteExpired is a no-op for these tests.
func (s *systemLockRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
	return 0, nil
}
// TestSystemOperationLockService_InputAndStoreErrorBranches exercises the
// guard clauses of Acquire: nil service, nil repo, empty operation id, and a
// failing backing store.
func TestSystemOperationLockService_InputAndStoreErrorBranches(t *testing.T) {
	cfg := IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	}
	// A nil service and a service without a repo both report the store as
	// unavailable.
	var nilSvc *SystemOperationLockService
	_, err := nilSvc.Acquire(context.Background(), "x")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	_, err = (&SystemOperationLockService{repo: nil}).Acquire(context.Background(), "x")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// An empty operation id is rejected up front.
	_, err = NewSystemOperationLockService(newInMemoryIdempotencyRepo(), cfg).Acquire(context.Background(), "")
	require.Error(t, err)
	require.Equal(t, "SYSTEM_OPERATION_ID_REQUIRED", infraerrors.Reason(err))
	// A failing store surfaces as store-unavailable.
	badStore := &systemLockRepoStub{createErr: errors.New("db down")}
	_, err = NewSystemOperationLockService(badStore, cfg).Acquire(context.Background(), "x")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
}
// TestSystemOperationLockService_ExistingNilAndReclaimBranches drives the
// non-owner path of Acquire through three branches: a missing existing
// record, a failing TryReclaim, and a lost reclaim race.
func TestSystemOperationLockService_ExistingNilAndReclaimBranches(t *testing.T) {
	now := time.Now()
	repo := &systemLockRepoStub{
		createOwner: false,
	}
	svc := NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	// Not owner + no existing record -> store-unavailable.
	_, err := svc.Acquire(context.Background(), "op")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Seed an expired, retryable record held by another operation so Acquire
	// proceeds to the reclaim attempt.
	repo.existing = &IdempotencyRecord{
		ID:                 1,
		Scope:              systemOperationLockScope,
		IdempotencyKeyHash: HashIdempotencyKey(systemOperationLockKey),
		RequestFingerprint: "other-op",
		Status:             IdempotencyStatusFailedRetryable,
		LockedUntil:        ptrTime(now.Add(-time.Second)),
		ExpiresAt:          now.Add(time.Hour),
	}
	// TryReclaim error -> store-unavailable.
	repo.reclaimErr = errors.New("reclaim failed")
	_, err = svc.Acquire(context.Background(), "op")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err))
	// Reclaim lost (no error, not reclaimed) -> busy.
	repo.reclaimErr = nil
	repo.reclaimOK = false
	_, err = svc.Acquire(context.Background(), "op")
	require.Error(t, err)
	require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
}
// TestSystemOperationLockService_ReleaseBranchesAndOperationID covers the
// Release branches (failure mark, double release, repo errors, nil service)
// plus the nil-safe OperationID accessor and metadata-less busyError.
func TestSystemOperationLockService_ReleaseBranchesAndOperationID(t *testing.T) {
	// OperationID on a nil lock is safe and empty.
	require.Equal(t, "", (*SystemOperationLock)(nil).OperationID())
	svc := NewSystemOperationLockService(newInMemoryIdempotencyRepo(), IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	lock, err := svc.Acquire(context.Background(), "op")
	require.NoError(t, err)
	require.NotNil(t, lock)
	// Releasing as failed and then again as succeeded both return nil.
	require.NoError(t, svc.Release(context.Background(), lock, false, ""))
	require.NoError(t, svc.Release(context.Background(), lock, true, ""))
	// A repo whose mark calls fail surfaces those errors from Release.
	repo := &systemLockRepoStub{
		createOwner: true,
		markSuccErr: errors.New("mark succeeded failed"),
		markFailErr: errors.New("mark failed failed"),
	}
	svc = NewSystemOperationLockService(repo, IdempotencyConfig{
		SystemOperationTTL: 10 * time.Second,
		ProcessingTimeout:  2 * time.Second,
	})
	lock = &SystemOperationLock{recordID: 1, operationID: "op2", stopCh: make(chan struct{})}
	require.Error(t, svc.Release(context.Background(), lock, true, ""))
	lock = &SystemOperationLock{recordID: 1, operationID: "op3", stopCh: make(chan struct{})}
	require.Error(t, svc.Release(context.Background(), lock, false, "BAD"))
	// Release on a nil service is a no-op.
	var nilLockSvc *SystemOperationLockService
	require.NoError(t, nilLockSvc.Release(context.Background(), nil, true, ""))
	// busyError without holder info still carries the busy code.
	err = svc.busyError("", nil, time.Now())
	require.Equal(t, infraerrors.Code(ErrSystemOperationBusy), infraerrors.Code(err))
}

View File

@@ -320,6 +320,10 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc
return err
}
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
if status == UsageCleanupStatusCanceled {
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent hit: task=%d operator=%d", taskID, canceledBy)
return nil
}
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
}
@@ -329,6 +333,11 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc
}
if !ok {
// 状态可能并发改变
currentStatus, getErr := s.repo.GetTaskStatus(ctx, taskID)
if getErr == nil && currentStatus == UsageCleanupStatusCanceled {
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task idempotent race hit: task=%d operator=%d", taskID, canceledBy)
return nil
}
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
}
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)

View File

@@ -644,6 +644,23 @@ func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) {
require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
}
// TestUsageCleanupServiceCancelTaskAlreadyCanceledIsIdempotent checks that
// canceling an already-canceled task is treated as an idempotent success and
// does not trigger another cancel write.
func TestUsageCleanupServiceCancelTaskAlreadyCanceledIsIdempotent(t *testing.T) {
	repo := &cleanupRepoStub{
		statusByID: map[int64]string{7: UsageCleanupStatusCanceled},
	}
	svc := NewUsageCleanupService(repo, nil, nil, &config.Config{
		UsageCleanup: config.UsageCleanupConfig{Enabled: true},
	})
	require.NoError(t, svc.CancelTask(context.Background(), 7, 1))
	repo.mu.Lock()
	defer repo.mu.Unlock()
	require.Empty(t, repo.cancelCalls, "already canceled should return success without extra cancel write")
}
func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) {
shouldCancel := false
repo := &cleanupRepoStub{

View File

@@ -225,6 +225,45 @@ func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Confi
return svc
}
// buildIdempotencyConfig derives the runtime idempotency settings from the
// application config, keeping the library defaults for every unset
// (non-positive) value. A nil config yields DefaultIdempotencyConfig.
func buildIdempotencyConfig(cfg *config.Config) IdempotencyConfig {
	out := DefaultIdempotencyConfig()
	if cfg == nil {
		return out
	}
	src := cfg.Idempotency
	if src.DefaultTTLSeconds > 0 {
		out.DefaultTTL = time.Duration(src.DefaultTTLSeconds) * time.Second
	}
	if src.SystemOperationTTLSeconds > 0 {
		out.SystemOperationTTL = time.Duration(src.SystemOperationTTLSeconds) * time.Second
	}
	if src.ProcessingTimeoutSeconds > 0 {
		out.ProcessingTimeout = time.Duration(src.ProcessingTimeoutSeconds) * time.Second
	}
	if src.FailedRetryBackoffSeconds > 0 {
		out.FailedRetryBackoff = time.Duration(src.FailedRetryBackoffSeconds) * time.Second
	}
	if src.MaxStoredResponseLen > 0 {
		out.MaxStoredResponseLen = src.MaxStoredResponseLen
	}
	out.ObserveOnly = src.ObserveOnly
	return out
}
// ProvideIdempotencyCoordinator constructs the coordinator and also installs
// it as the process-wide default so call sites outside the wire graph can
// reach it.
func ProvideIdempotencyCoordinator(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCoordinator {
	c := NewIdempotencyCoordinator(repo, buildIdempotencyConfig(cfg))
	SetDefaultIdempotencyCoordinator(c)
	return c
}
// ProvideSystemOperationLockService builds the global system-operation lock
// service backed by the shared idempotency repository.
func ProvideSystemOperationLockService(repo IdempotencyRepository, cfg *config.Config) *SystemOperationLockService {
	return NewSystemOperationLockService(repo, buildIdempotencyConfig(cfg))
}
// ProvideIdempotencyCleanupService creates the background cleanup service
// for expired idempotency records and starts it immediately.
func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService {
	svc := NewIdempotencyCleanupService(repo, cfg)
	svc.Start()
	return svc
}
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
func ProvideOpsScheduledReportService(
opsService *OpsService,
@@ -318,4 +357,7 @@ var ProviderSet = wire.NewSet(
NewTotpService,
NewErrorPassthroughService,
NewDigestSessionStore,
ProvideIdempotencyCoordinator,
ProvideSystemOperationLockService,
ProvideIdempotencyCleanupService,
)