338 lines
11 KiB
Go
338 lines
11 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"strings"
|
||
|
||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
)
|
||
|
||
// usageBillingRepository applies usage billing commands directly against a
// database/sql handle.
type usageBillingRepository struct {
	// db is the handle Apply uses to open its transaction.
	db *sql.DB
}
|
||
|
||
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
||
return &usageBillingRepository{db: sqlDB}
|
||
}
|
||
|
||
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
||
if cmd == nil {
|
||
return &service.UsageBillingApplyResult{}, nil
|
||
}
|
||
if r == nil || r.db == nil {
|
||
return nil, errors.New("usage billing repository db is nil")
|
||
}
|
||
|
||
cmd.Normalize()
|
||
if cmd.RequestID == "" {
|
||
return nil, service.ErrUsageBillingRequestIDRequired
|
||
}
|
||
|
||
tx, err := r.db.BeginTx(ctx, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
if tx != nil {
|
||
_ = tx.Rollback()
|
||
}
|
||
}()
|
||
|
||
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if !applied {
|
||
return &service.UsageBillingApplyResult{Applied: false}, nil
|
||
}
|
||
|
||
result := &service.UsageBillingApplyResult{Applied: true}
|
||
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if err := tx.Commit(); err != nil {
|
||
return nil, err
|
||
}
|
||
tx = nil
|
||
return result, nil
|
||
}
|
||
|
||
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
|
||
var id int64
|
||
err := tx.QueryRowContext(ctx, `
|
||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
|
||
VALUES ($1, $2, $3)
|
||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||
RETURNING id
|
||
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
|
||
if errors.Is(err, sql.ErrNoRows) {
|
||
var existingFingerprint string
|
||
if err := tx.QueryRowContext(ctx, `
|
||
SELECT request_fingerprint
|
||
FROM usage_billing_dedup
|
||
WHERE request_id = $1 AND api_key_id = $2
|
||
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
|
||
return false, err
|
||
}
|
||
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||
return false, service.ErrUsageBillingRequestConflict
|
||
}
|
||
return false, nil
|
||
}
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
var archivedFingerprint string
|
||
err = tx.QueryRowContext(ctx, `
|
||
SELECT request_fingerprint
|
||
FROM usage_billing_dedup_archive
|
||
WHERE request_id = $1 AND api_key_id = $2
|
||
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
|
||
if err == nil {
|
||
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||
return false, service.ErrUsageBillingRequestConflict
|
||
}
|
||
return false, nil
|
||
}
|
||
if !errors.Is(err, sql.ErrNoRows) {
|
||
return false, err
|
||
}
|
||
return true, nil
|
||
}
|
||
|
||
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
|
||
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
|
||
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
if cmd.BalanceCost > 0 {
|
||
newBalance, err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
result.NewBalance = &newBalance
|
||
}
|
||
|
||
if cmd.APIKeyQuotaCost > 0 {
|
||
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
result.APIKeyQuotaExhausted = exhausted
|
||
}
|
||
|
||
if cmd.APIKeyRateLimitCost > 0 {
|
||
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
|
||
quotaState, err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
result.QuotaState = quotaState
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
||
const updateSQL = `
|
||
UPDATE user_subscriptions us
|
||
SET
|
||
daily_usage_usd = us.daily_usage_usd + $1,
|
||
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||
updated_at = NOW()
|
||
FROM groups g
|
||
WHERE us.id = $2
|
||
AND us.deleted_at IS NULL
|
||
AND us.group_id = g.id
|
||
AND g.deleted_at IS NULL
|
||
`
|
||
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
affected, err := res.RowsAffected()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if affected > 0 {
|
||
return nil
|
||
}
|
||
return service.ErrSubscriptionNotFound
|
||
}
|
||
|
||
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) (float64, error) {
|
||
var newBalance float64
|
||
err := tx.QueryRowContext(ctx, `
|
||
UPDATE users
|
||
SET balance = balance - $1,
|
||
updated_at = NOW()
|
||
WHERE id = $2 AND deleted_at IS NULL
|
||
RETURNING balance
|
||
`, amount, userID).Scan(&newBalance)
|
||
if errors.Is(err, sql.ErrNoRows) {
|
||
return 0, service.ErrUserNotFound
|
||
}
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return newBalance, nil
|
||
}
|
||
|
||
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
||
var exhausted bool
|
||
err := tx.QueryRowContext(ctx, `
|
||
UPDATE api_keys
|
||
SET quota_used = quota_used + $1,
|
||
status = CASE
|
||
WHEN quota > 0
|
||
AND status = $3
|
||
AND quota_used < quota
|
||
AND quota_used + $1 >= quota
|
||
THEN $4
|
||
ELSE status
|
||
END,
|
||
updated_at = NOW()
|
||
WHERE id = $2 AND deleted_at IS NULL
|
||
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
|
||
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
|
||
if errors.Is(err, sql.ErrNoRows) {
|
||
return false, service.ErrAPIKeyNotFound
|
||
}
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return exhausted, nil
|
||
}
|
||
|
||
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
||
res, err := tx.ExecContext(ctx, `
|
||
UPDATE api_keys SET
|
||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||
updated_at = NOW()
|
||
WHERE id = $2 AND deleted_at IS NULL
|
||
`, cost, apiKeyID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
affected, err := res.RowsAffected()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if affected == 0 {
|
||
return service.ErrAPIKeyNotFound
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) (*service.AccountQuotaState, error) {
|
||
rows, err := tx.QueryContext(ctx,
|
||
`UPDATE accounts SET extra = (
|
||
COALESCE(extra, '{}'::jsonb)
|
||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||
jsonb_build_object(
|
||
'quota_daily_used',
|
||
CASE WHEN `+dailyExpiredExpr+`
|
||
THEN $1
|
||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||
'quota_daily_start',
|
||
CASE WHEN `+dailyExpiredExpr+`
|
||
THEN `+nowUTC+`
|
||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||
)
|
||
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||
ELSE '{}'::jsonb END
|
||
ELSE '{}'::jsonb END
|
||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||
jsonb_build_object(
|
||
'quota_weekly_used',
|
||
CASE WHEN `+weeklyExpiredExpr+`
|
||
THEN $1
|
||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||
'quota_weekly_start',
|
||
CASE WHEN `+weeklyExpiredExpr+`
|
||
THEN `+nowUTC+`
|
||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||
)
|
||
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||
ELSE '{}'::jsonb END
|
||
ELSE '{}'::jsonb END
|
||
), updated_at = NOW()
|
||
WHERE id = $2 AND deleted_at IS NULL
|
||
RETURNING
|
||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||
COALESCE((extra->>'quota_limit')::numeric, 0),
|
||
COALESCE((extra->>'quota_daily_used')::numeric, 0),
|
||
COALESCE((extra->>'quota_daily_limit')::numeric, 0),
|
||
COALESCE((extra->>'quota_weekly_used')::numeric, 0),
|
||
COALESCE((extra->>'quota_weekly_limit')::numeric, 0)`,
|
||
amount, accountID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var state service.AccountQuotaState
|
||
if rows.Next() {
|
||
if err := rows.Scan(
|
||
&state.TotalUsed, &state.TotalLimit,
|
||
&state.DailyUsed, &state.DailyLimit,
|
||
&state.WeeklyUsed, &state.WeeklyLimit,
|
||
); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
} else {
|
||
if err := rows.Err(); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
_ = rows.Close()
|
||
return nil, service.ErrAccountNotFound
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
// 必须在执行下一条 SQL 前显式关闭 rows:pq 驱动在同一连接上
|
||
// 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回
|
||
// "unexpected Parse response" 错误。
|
||
if err := rows.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
// 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照,
|
||
// 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号,
|
||
// 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。
|
||
// 对于日/周额度,即使本次触发了周期重置(pre=0、post=amount),
|
||
// 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。
|
||
crossedTotal := state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit
|
||
crossedDaily := state.DailyLimit > 0 && state.DailyUsed >= state.DailyLimit && (state.DailyUsed-amount) < state.DailyLimit
|
||
crossedWeekly := state.WeeklyLimit > 0 && state.WeeklyUsed >= state.WeeklyLimit && (state.WeeklyUsed-amount) < state.WeeklyLimit
|
||
if crossedTotal || crossedDaily || crossedWeekly {
|
||
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||
return nil, err
|
||
}
|
||
}
|
||
return &state, nil
|
||
}
|