Merge remote-tracking branch 'upstream/main' into feat/channel-insights
# Conflicts: # backend/cmd/server/wire_gen.go
This commit is contained in:
@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
user.FieldSignupSource,
|
||||
user.FieldLastLoginAt,
|
||||
user.FieldLastActiveAt,
|
||||
user.FieldRpmLimit,
|
||||
)
|
||||
}).
|
||||
WithGroup(func(q *dbent.GroupQuery) {
|
||||
@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
group.FieldMessagesDispatchModelConfig,
|
||||
group.FieldRpmLimit,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
|
||||
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
||||
TotalRecharged: u.TotalRecharged,
|
||||
RPMLimit: u.RpmLimit,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
RequirePrivacySet: g.RequirePrivacySet,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
|
||||
RPMLimit: g.RpmLimit,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||
SetRpmLimit(groupIn.RPMLimit)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||
SetRpmLimit(groupIn.RPMLimit)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
|
||||
51
backend/internal/repository/openai_403_counter_cache.go
Normal file
51
backend/internal/repository/openai_403_counter_cache.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const openAI403CounterPrefix = "openai_403_count:account:"
|
||||
|
||||
var openAI403CounterIncrScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
local count = redis.call('INCR', key)
|
||||
if count == 1 then
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
|
||||
return count
|
||||
`)
|
||||
|
||||
type openAI403CounterCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache {
|
||||
return &openAI403CounterCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
|
||||
key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
|
||||
|
||||
ttlSeconds := windowMinutes * 60
|
||||
if ttlSeconds < 60 {
|
||||
ttlSeconds = 60
|
||||
}
|
||||
|
||||
result, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("increment openai 403 count: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -53,6 +54,9 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
|
||||
return nil, newOpenAINoProxyHintError(err)
|
||||
}
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -98,6 +102,9 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
|
||||
return nil, newOpenAINoProxyHintError(err)
|
||||
}
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -114,3 +121,21 @@ func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
|
||||
Timeout: 120 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
func shouldReturnOpenAINoProxyHint(ctx context.Context, proxyURL string, err error) bool {
|
||||
if strings.TrimSpace(proxyURL) != "" || err == nil {
|
||||
return false
|
||||
}
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
return false
|
||||
}
|
||||
return !errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func newOpenAINoProxyHintError(cause error) error {
|
||||
return infraerrors.New(
|
||||
http.StatusBadGateway,
|
||||
"OPENAI_OAUTH_PROXY_REQUIRED",
|
||||
"OpenAI OAuth request failed: no proxy is configured and this server could not reach OpenAI directly. Select a proxy that can access OpenAI, then retry; if the authorization code has expired, regenerate the authorization URL.",
|
||||
).WithCause(cause)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -204,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
|
||||
require.ErrorContains(s.T(), err, "request failed")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_RequestErrorWithoutProxyReturnsProxyHint() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.srv.Close()
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
|
||||
require.Error(s.T(), err)
|
||||
require.Equal(s.T(), "OPENAI_OAUTH_PROXY_REQUIRED", infraerrors.Reason(err))
|
||||
require.Contains(s.T(), infraerrors.Message(err), "no proxy is configured")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
||||
started := make(chan struct{})
|
||||
block := make(chan struct{})
|
||||
|
||||
@@ -290,7 +290,6 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var state service.AccountQuotaState
|
||||
if rows.Next() {
|
||||
@@ -299,18 +298,36 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|
||||
&state.DailyUsed, &state.DailyLimit,
|
||||
&state.WeeklyUsed, &state.WeeklyLimit,
|
||||
); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := rows.Err(); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
_ = rows.Close()
|
||||
return nil, service.ErrAccountNotFound
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit {
|
||||
// 必须在执行下一条 SQL 前显式关闭 rows:pq 驱动在同一连接上
|
||||
// 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回
|
||||
// "unexpected Parse response" 错误。
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照,
|
||||
// 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号,
|
||||
// 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。
|
||||
// 对于日/周额度,即使本次触发了周期重置(pre=0、post=amount),
|
||||
// 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。
|
||||
crossedTotal := state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit
|
||||
crossedDaily := state.DailyLimit > 0 && state.DailyUsed >= state.DailyLimit && (state.DailyUsed-amount) < state.DailyLimit
|
||||
crossedWeekly := state.WeeklyLimit > 0 && state.WeeklyUsed >= state.WeeklyLimit && (state.WeeklyUsed-amount) < state.WeeklyLimit
|
||||
if crossedTotal || crossedDaily || crossedWeekly {
|
||||
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||
return nil, err
|
||||
|
||||
@@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
||||
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_EnqueuesSchedulerOutboxOnQuotaCrossing(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
newFixture := func(t *testing.T, extra map[string]any) (int64, int64) {
|
||||
t.Helper()
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-outbox-user-%d-%s@example.com", time.Now().UnixNano(), uuid.NewString()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-outbox-" + uuid.NewString(),
|
||||
Name: "billing-outbox",
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-outbox-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Extra: extra,
|
||||
})
|
||||
return apiKey.ID, account.ID
|
||||
}
|
||||
|
||||
outboxCountFor := func(t *testing.T, accountID int64) int {
|
||||
t.Helper()
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1 AND account_id = $2",
|
||||
service.SchedulerOutboxEventAccountChanged, accountID,
|
||||
).Scan(&count))
|
||||
return count
|
||||
}
|
||||
|
||||
t.Run("daily_first_crossing_enqueues", func(t *testing.T) {
|
||||
apiKeyID, accountID := newFixture(t, map[string]any{
|
||||
"quota_daily_limit": 10.0,
|
||||
})
|
||||
// 第一次低于日限额:不应入队 outbox
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 4,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, outboxCountFor(t, accountID), "below limit should not enqueue")
|
||||
|
||||
// 第二次跨越日限额:应入队一次 outbox
|
||||
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 8,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, outboxCountFor(t, accountID), "crossing daily limit should enqueue once")
|
||||
|
||||
// 再次递增(已超):不应重复入队
|
||||
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, outboxCountFor(t, accountID), "subsequent increments beyond limit should not re-enqueue")
|
||||
})
|
||||
|
||||
t.Run("weekly_first_crossing_enqueues", func(t *testing.T) {
|
||||
apiKeyID, accountID := newFixture(t, map[string]any{
|
||||
"quota_weekly_limit": 10.0,
|
||||
})
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 15, // 单次即跨越
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, outboxCountFor(t, accountID), "single-shot crossing weekly limit should enqueue once")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
|
||||
// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
|
||||
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
||||
return &userGroupRateRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
|
||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserIDs 批量获取多个用户的专属分组倍率。
|
||||
// 返回结构:map[userID]map[groupID]rate
|
||||
// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
|
||||
result := make(map[int64]map[int64]float64, len(userIDs))
|
||||
if len(userIDs) == 0 {
|
||||
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT user_id, group_id, rate_multiplier
|
||||
FROM user_group_rate_multipliers
|
||||
WHERE user_id = ANY($1)
|
||||
WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
|
||||
`, pq.Array(uniqueIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||
// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
|
||||
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||
query := `
|
||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
|
||||
FROM user_group_rate_multipliers ugr
|
||||
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
||||
WHERE ugr.group_id = $1
|
||||
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
||||
var result []service.UserGroupRateEntry
|
||||
for rows.Next() {
|
||||
var entry service.UserGroupRateEntry
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
|
||||
var rate sql.NullFloat64
|
||||
var rpm sql.NullInt32
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rate.Valid {
|
||||
v := rate.Float64
|
||||
entry.RateMultiplier = &v
|
||||
}
|
||||
if rpm.Valid {
|
||||
v := int(rpm.Int32)
|
||||
entry.RPMOverride = &v
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
var rate float64
|
||||
var rate sql.NullFloat64
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rate, nil
|
||||
if !rate.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
v := rate.Float64
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
|
||||
func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
|
||||
query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
var rpm sql.NullInt32
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !rpm.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
v := int(rpm.Int32)
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
|
||||
// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
|
||||
// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
|
||||
// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
|
||||
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
||||
if len(rates) == 0 {
|
||||
// 如果传入空 map,删除该用户的所有专属倍率
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
`, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
|
||||
userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// 分离需要删除和需要 upsert 的记录
|
||||
var toDelete []int64
|
||||
var clearGroupIDs []int64
|
||||
upsertGroupIDs := make([]int64, 0, len(rates))
|
||||
upsertRates := make([]float64, 0, len(rates))
|
||||
for groupID, rate := range rates {
|
||||
if rate == nil {
|
||||
toDelete = append(toDelete, groupID)
|
||||
clearGroupIDs = append(clearGroupIDs, groupID)
|
||||
} else {
|
||||
upsertGroupIDs = append(upsertGroupIDs, groupID)
|
||||
upsertRates = append(upsertRates, *rate)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除指定的记录
|
||||
if len(toDelete) > 0 {
|
||||
if len(clearGroupIDs) > 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE user_id = $1 AND group_id = ANY($2)
|
||||
`, userID, pq.Array(clearGroupIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
|
||||
userID, pq.Array(toDelete)); err != nil {
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
|
||||
userID, pq.Array(clearGroupIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert 记录
|
||||
now := time.Now()
|
||||
if len(upsertGroupIDs) > 0 {
|
||||
now := time.Now()
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||
SELECT
|
||||
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
|
||||
// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
|
||||
// 语义:
|
||||
// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
|
||||
// - 出现的用户行:upsert rate_multiplier。
|
||||
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
|
||||
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
|
||||
keepUserIDs := make([]int64, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
keepUserIDs = append(keepUserIDs, e.UserID)
|
||||
}
|
||||
|
||||
// 未在 entries 列表中的行:清空 rate_multiplier。
|
||||
if len(keepUserIDs) == 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1 AND user_id <> ALL($2)
|
||||
`, groupID, pq.Array(keepUserIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 清空后若整行 NULL 则删除。
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM user_group_rate_multipliers
|
||||
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
userIDs := make([]int64, len(entries))
|
||||
rates := make([]float64, len(entries))
|
||||
for i, e := range entries {
|
||||
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||
// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
|
||||
// 语义:
|
||||
// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
|
||||
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
|
||||
func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
|
||||
keepUserIDs := make([]int64, 0, len(entries))
|
||||
var clearUserIDs []int64
|
||||
upsertUserIDs := make([]int64, 0, len(entries))
|
||||
upsertValues := make([]int32, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
keepUserIDs = append(keepUserIDs, e.UserID)
|
||||
if e.RPMOverride == nil {
|
||||
clearUserIDs = append(clearUserIDs, e.UserID)
|
||||
} else {
|
||||
upsertUserIDs = append(upsertUserIDs, e.UserID)
|
||||
upsertValues = append(upsertValues, int32(*e.RPMOverride))
|
||||
}
|
||||
}
|
||||
|
||||
// 未在 entries 列表中的行:清空 rpm_override。
|
||||
if len(keepUserIDs) == 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1 AND user_id <> ALL($2)
|
||||
`, groupID, pq.Array(keepUserIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 显式 clear 的行。
|
||||
if len(clearUserIDs) > 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1 AND user_id = ANY($2)
|
||||
`, groupID, pq.Array(clearUserIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 清空后若整行 NULL 则删除。
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM user_group_rate_multipliers
|
||||
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(upsertUserIDs) > 0 {
|
||||
now := time.Now()
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
|
||||
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
|
||||
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
|
||||
ON CONFLICT (user_id, group_id)
|
||||
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
|
||||
`, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
|
||||
func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM user_group_rate_multipliers
|
||||
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||
`, groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属条目
|
||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率
|
||||
// DeleteByUserID 删除指定用户的所有专属条目
|
||||
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
return err
|
||||
|
||||
@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
||||
SetNillableLastLoginAt(userIn.LastLoginAt).
|
||||
SetNillableLastActiveAt(userIn.LastActiveAt).
|
||||
SetRpmLimit(userIn.RPMLimit).
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
|
||||
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
|
||||
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
|
||||
SetTotalRecharged(userIn.TotalRecharged)
|
||||
SetTotalRecharged(userIn.TotalRecharged).
|
||||
SetRpmLimit(userIn.RPMLimit)
|
||||
if userIn.SignupSource != "" {
|
||||
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
|
||||
}
|
||||
|
||||
108
backend/internal/repository/user_rpm_cache.go
Normal file
108
backend/internal/repository/user_rpm_cache.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 用户/分组级 RPM 计数器 Redis 实现。
|
||||
//
|
||||
// 设计说明:
|
||||
// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
|
||||
// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
|
||||
// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
|
||||
// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
|
||||
// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
|
||||
const (
|
||||
userGroupRPMKeyPrefix = "rpm:ug:"
|
||||
userRPMKeyPrefix = "rpm:u:"
|
||||
|
||||
userRPMKeyTTL = 120 * time.Second
|
||||
)
|
||||
|
||||
type userRPMCacheImpl struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
|
||||
func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
|
||||
return &userRPMCacheImpl{rdb: rdb}
|
||||
}
|
||||
|
||||
// minuteTS 获取当前 Redis 服务端分钟时间戳。
|
||||
func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
|
||||
t, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
return t.Unix() / 60, nil
|
||||
}
|
||||
|
||||
// atomicIncr 原子 INCR+EXPIRE。
|
||||
func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
|
||||
pipe := c.rdb.TxPipeline()
|
||||
incr := pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, userRPMKeyTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, fmt.Errorf("user rpm increment: %w", err)
|
||||
}
|
||||
return int(incr.Val()), nil
|
||||
}
|
||||
|
||||
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
|
||||
func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
|
||||
return c.atomicIncr(ctx, key)
|
||||
}
|
||||
|
||||
// IncrementUserRPM 递增用户分钟计数。
|
||||
func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
|
||||
return c.atomicIncr(ctx, key)
|
||||
}
|
||||
|
||||
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
|
||||
func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("user group rpm get: %w", err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
|
||||
func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("user rpm get: %w", err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
@@ -98,10 +98,12 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyCache,
|
||||
NewTempUnschedCache,
|
||||
NewTimeoutCounterCache,
|
||||
NewOpenAI403CounterCache,
|
||||
NewInternal500CounterCache,
|
||||
ProvideConcurrencyCache,
|
||||
ProvideSessionLimitCache,
|
||||
NewRPMCache,
|
||||
NewUserRPMCache,
|
||||
NewUserMsgQueueCache,
|
||||
NewDashboardCache,
|
||||
NewEmailCache,
|
||||
|
||||
Reference in New Issue
Block a user