Merge remote-tracking branch 'upstream/main' into feat/channel-insights

# Conflicts:
#	backend/cmd/server/wire_gen.go
This commit is contained in:
erio
2026-04-23 22:30:45 +08:00
106 changed files with 5109 additions and 1427 deletions

View File

@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldSignupSource,
user.FieldLastLoginAt,
user.FieldLastActiveAt,
user.FieldRpmLimit,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel,
group.FieldMessagesDispatchModelConfig,
group.FieldRpmLimit,
)
}).
Only(ctx)
@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
TotalRecharged: u.TotalRecharged,
RPMLimit: u.RpmLimit,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequirePrivacySet: g.RequirePrivacySet,
DefaultMappedModel: g.DefaultMappedModel,
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
RPMLimit: g.RpmLimit,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}

View File

@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
SetRpmLimit(groupIn.RPMLimit)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
SetRpmLimit(groupIn.RPMLimit)
// 显式处理可空字段nil 需要 clear非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {

View File

@@ -0,0 +1,51 @@
package repository
import (
"context"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const openAI403CounterPrefix = "openai_403_count:account:"
var openAI403CounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local count = redis.call('INCR', key)
if count == 1 then
redis.call('EXPIRE', key, ttl)
end
return count
`)
type openAI403CounterCache struct {
rdb *redis.Client
}
func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache {
return &openAI403CounterCache{rdb: rdb}
}
func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
ttlSeconds := windowMinutes * 60
if ttlSeconds < 60 {
ttlSeconds = 60
}
result, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
if err != nil {
return 0, fmt.Errorf("increment openai 403 count: %w", err)
}
return result, nil
}
func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"errors"
"net/http"
"net/url"
"strings"
@@ -53,6 +54,9 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
Post(s.tokenURL)
if err != nil {
if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
return nil, newOpenAINoProxyHintError(err)
}
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
@@ -98,6 +102,9 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre
Post(s.tokenURL)
if err != nil {
if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
return nil, newOpenAINoProxyHintError(err)
}
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
@@ -114,3 +121,21 @@ func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
Timeout: 120 * time.Second,
})
}
func shouldReturnOpenAINoProxyHint(ctx context.Context, proxyURL string, err error) bool {
if strings.TrimSpace(proxyURL) != "" || err == nil {
return false
}
if ctx != nil && ctx.Err() != nil {
return false
}
return !errors.Is(err, context.Canceled)
}
func newOpenAINoProxyHintError(cause error) error {
return infraerrors.New(
http.StatusBadGateway,
"OPENAI_OAUTH_PROXY_REQUIRED",
"OpenAI OAuth request failed: no proxy is configured and this server could not reach OpenAI directly. Select a proxy that can access OpenAI, then retry; if the authorization code has expired, regenerate the authorization URL.",
).WithCause(cause)
}

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@@ -204,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
require.ErrorContains(s.T(), err, "request failed")
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_RequestErrorWithoutProxyReturnsProxyHint() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.Equal(s.T(), "OPENAI_OAUTH_PROXY_REQUIRED", infraerrors.Reason(err))
require.Contains(s.T(), infraerrors.Message(err), "no proxy is configured")
}
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
started := make(chan struct{})
block := make(chan struct{})

View File

@@ -290,7 +290,6 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var state service.AccountQuotaState
if rows.Next() {
@@ -299,18 +298,36 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
&state.DailyUsed, &state.DailyLimit,
&state.WeeklyUsed, &state.WeeklyLimit,
); err != nil {
_ = rows.Close()
return nil, err
}
} else {
if err := rows.Err(); err != nil {
_ = rows.Close()
return nil, err
}
_ = rows.Close()
return nil, service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
_ = rows.Close()
return nil, err
}
if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit {
// 必须在执行下一条 SQL 前显式关闭 rowspq 驱动在同一连接上
// 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回
// "unexpected Parse response" 错误。
if err := rows.Close(); err != nil {
return nil, err
}
// 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照,
// 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号,
// 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。
// 对于日/周额度即使本次触发了周期重置pre=0、post=amount
// 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。
crossedTotal := state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit
crossedDaily := state.DailyLimit > 0 && state.DailyUsed >= state.DailyLimit && (state.DailyUsed-amount) < state.DailyLimit
crossedWeekly := state.WeeklyLimit > 0 && state.WeeklyUsed >= state.WeeklyLimit && (state.WeeklyUsed-amount) < state.WeeklyLimit
if crossedTotal || crossedDaily || crossedWeekly {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return nil, err

View File

@@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
require.InDelta(t, 3.5, quotaUsed, 0.000001)
}
func TestUsageBillingRepositoryApply_EnqueuesSchedulerOutboxOnQuotaCrossing(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
newFixture := func(t *testing.T, extra map[string]any) (int64, int64) {
t.Helper()
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-outbox-user-%d-%s@example.com", time.Now().UnixNano(), uuid.NewString()),
PasswordHash: "hash",
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-outbox-" + uuid.NewString(),
Name: "billing-outbox",
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-outbox-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
Extra: extra,
})
return apiKey.ID, account.ID
}
outboxCountFor := func(t *testing.T, accountID int64) int {
t.Helper()
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx,
"SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1 AND account_id = $2",
service.SchedulerOutboxEventAccountChanged, accountID,
).Scan(&count))
return count
}
t.Run("daily_first_crossing_enqueues", func(t *testing.T) {
apiKeyID, accountID := newFixture(t, map[string]any{
"quota_daily_limit": 10.0,
})
// 第一次低于日限额:不应入队 outbox
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKeyID,
AccountID: accountID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 4,
})
require.NoError(t, err)
require.Equal(t, 0, outboxCountFor(t, accountID), "below limit should not enqueue")
// 第二次跨越日限额:应入队一次 outbox
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKeyID,
AccountID: accountID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 8,
})
require.NoError(t, err)
require.Equal(t, 1, outboxCountFor(t, accountID), "crossing daily limit should enqueue once")
// 再次递增(已超):不应重复入队
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKeyID,
AccountID: accountID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 2,
})
require.NoError(t, err)
require.Equal(t, 1, outboxCountFor(t, accountID), "subsequent increments beyond limit should not re-enqueue")
})
t.Run("weekly_first_crossing_enqueues", func(t *testing.T) {
apiKeyID, accountID := newFixture(t, map[string]any{
"quota_weekly_limit": 10.0,
})
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKeyID,
AccountID: accountID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 15, // 单次即跨越
})
require.NoError(t, err)
require.Equal(t, 1, outboxCountFor(t, accountID), "single-shot crossing weekly limit should enqueue once")
})
}
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
ctx := context.Background()
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)

View File

@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
sql sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
// GetByUserID 获取用户所有专属分组倍率
// GetByUserID 获取用户所有专属分组 rate_multiplier仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil
}
// GetByUserIDs 批量获取多个用户的专属分组倍率。
// 返回结构map[userID]map[groupID]rate
// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
result := make(map[int64]map[int64]float64, len(userIDs))
if len(userIDs) == 0 {
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
rows, err := r.sql.QueryContext(ctx, `
SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers
WHERE user_id = ANY($1)
WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
`, pq.Array(uniqueIDs))
if err != nil {
return nil, err
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return result, nil
}
// GetByGroupID 获取指定分组下所有用户的专属倍率
// GetByGroupID 获取指定分组下所有用户的专属配置rate 与 rpm_override 任一非 NULL 即返回)
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
query := `
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var result []service.UserGroupRateEntry
for rows.Next() {
var entry service.UserGroupRateEntry
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
var rate sql.NullFloat64
var rpm sql.NullInt32
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
return nil, err
}
if rate.Valid {
v := rate.Float64
entry.RateMultiplier = &v
}
if rpm.Valid {
v := int(rpm.Int32)
entry.RPMOverride = &v
}
result = append(result, entry)
}
if err := rows.Err(); err != nil {
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplierNULL 返回 nil
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64
var rate sql.NullFloat64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
if err != nil {
return nil, err
}
return &rate, nil
if !rate.Valid {
return nil, nil
}
v := rate.Float64
return &v, nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_overrideNULL 返回 nil
func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rpm sql.NullInt32
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
if !rpm.Valid {
return nil, nil
}
v := int(rpm.Int32)
return &v, nil
}
// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
// - 传入空 map清空该用户所有行的 rate_multiplier若 rpm_override 也为 NULL 则整行删除。
// - 值为 nil清空对应行的 rate_multiplier保留 rpm_override
// - 值非 nilupsert rate_multiplier保留已有 rpm_override
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
// 如果传入空 map删除该用户的所有专属倍率
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1
`, userID); err != nil {
return err
}
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
userID)
return err
}
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
var clearGroupIDs []int64
upsertGroupIDs := make([]int64, 0, len(rates))
upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
clearGroupIDs = append(clearGroupIDs, groupID)
} else {
upsertGroupIDs = append(upsertGroupIDs, groupID)
upsertRates = append(upsertRates, *rate)
}
}
// 删除指定的记录
if len(toDelete) > 0 {
if len(clearGroupIDs) > 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1 AND group_id = ANY($2)
`, userID, pq.Array(clearGroupIDs)); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
userID, pq.Array(toDelete)); err != nil {
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
userID, pq.Array(clearGroupIDs)); err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
if len(upsertGroupIDs) > 0 {
now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return nil
}
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override
// 语义:
// - 未出现在 entries 中的用户行rate_multiplier 归 NULL若 rpm_override 也为 NULL 则整行删除。
// - 出现的用户行upsert rate_multiplier。
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
keepUserIDs := make([]int64, 0, len(entries))
for _, e := range entries {
keepUserIDs = append(keepUserIDs, e.UserID)
}
// 未在 entries 列表中的行:清空 rate_multiplier。
if len(keepUserIDs) == 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
} else {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`, groupID, pq.Array(keepUserIDs)); err != nil {
return err
}
}
// 清空后若整行 NULL 则删除。
if _, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID); err != nil {
return err
}
if len(entries) == 0 {
return nil
}
userIDs := make([]int64, len(entries))
rates := make([]float64, len(entries))
for i, e := range entries {
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
return err
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier
// 语义:
// - 未出现的用户行rpm_override 归 NULL若 rate_multiplier 也为 NULL 则整行删除。
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
keepUserIDs := make([]int64, 0, len(entries))
var clearUserIDs []int64
upsertUserIDs := make([]int64, 0, len(entries))
upsertValues := make([]int32, 0, len(entries))
for _, e := range entries {
keepUserIDs = append(keepUserIDs, e.UserID)
if e.RPMOverride == nil {
clearUserIDs = append(clearUserIDs, e.UserID)
} else {
upsertUserIDs = append(upsertUserIDs, e.UserID)
upsertValues = append(upsertValues, int32(*e.RPMOverride))
}
}
// 未在 entries 列表中的行:清空 rpm_override。
if len(keepUserIDs) == 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
} else {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`, groupID, pq.Array(keepUserIDs)); err != nil {
return err
}
}
// 显式 clear 的行。
if len(clearUserIDs) > 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id = ANY($2)
`, groupID, pq.Array(clearUserIDs)); err != nil {
return err
}
}
// 清空后若整行 NULL 则删除。
if _, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID); err != nil {
return err
}
if len(upsertUserIDs) > 0 {
now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
`, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
if err != nil {
return err
}
}
return nil
}
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
_, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID)
return err
}
// DeleteByGroupID 删除指定分组的所有用户专属条目
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
// DeleteByUserID 删除指定用户的所有专属倍率
// DeleteByUserID 删除指定用户的所有专属条目
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err

View File

@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt).
SetRpmLimit(userIn.RPMLimit).
Save(txCtx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
SetTotalRecharged(userIn.TotalRecharged).
SetRpmLimit(userIn.RPMLimit)
if userIn.SignupSource != "" {
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
}

View File

@@ -0,0 +1,108 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 用户/分组级 RPM 计数器 Redis 实现。
//
// 设计说明:
// - key 形式rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
// - 时间来源rdb.Time()Redis 服务端时间),避免多实例时钟漂移。
// - 原子操作TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE兼容 Redis Cluster。
// - TTL120s覆盖当前分钟窗口 + 少量冗余。
// - 返回值语义超限判断由调用方billing_cache_service.checkRPM与 RPMLimit 比较完成。
const (
userGroupRPMKeyPrefix = "rpm:ug:"
userRPMKeyPrefix = "rpm:u:"
userRPMKeyTTL = 120 * time.Second
)
type userRPMCacheImpl struct {
rdb *redis.Client
}
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
return &userRPMCacheImpl{rdb: rdb}
}
// minuteTS 获取当前 Redis 服务端分钟时间戳。
func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
t, err := c.rdb.Time(ctx).Result()
if err != nil {
return 0, fmt.Errorf("redis TIME: %w", err)
}
return t.Unix() / 60, nil
}
// atomicIncr 原子 INCR+EXPIRE。
func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
pipe := c.rdb.TxPipeline()
incr := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, userRPMKeyTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("user rpm increment: %w", err)
}
return int(incr.Val()), nil
}
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
return c.atomicIncr(ctx, key)
}
// IncrementUserRPM 递增用户分钟计数。
func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
return c.atomicIncr(ctx, key)
}
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM只读
func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
val, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("user group rpm get: %w", err)
}
return val, nil
}
// GetUserRPM 获取用户当前分钟已用 RPM只读
func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
val, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("user rpm get: %w", err)
}
return val, nil
}

View File

@@ -98,10 +98,12 @@ var ProviderSet = wire.NewSet(
NewAPIKeyCache,
NewTempUnschedCache,
NewTimeoutCounterCache,
NewOpenAI403CounterCache,
NewInternal500CounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewRPMCache,
NewUserRPMCache,
NewUserMsgQueueCache,
NewDashboardCache,
NewEmailCache,